diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 98699d648f..8352f9ec1c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -860,7 +860,7 @@ if(${TINT_BUILD_TESTS}) if(${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_WGSL_WRITER}) list(APPEND TINT_TEST_SRCS - transform/array_length_from_uniform.cc + transform/array_length_from_uniform_test.cc transform/binding_remapper_test.cc transform/bound_array_accessors_test.cc transform/calculate_array_length_test.cc diff --git a/src/clone_context.cc b/src/clone_context.cc index 16fc90bff0..993431eb03 100644 --- a/src/clone_context.cc +++ b/src/clone_context.cc @@ -57,6 +57,7 @@ Symbol CloneContext::Clone(Symbol s) { void CloneContext::Clone() { dst->AST().Copy(this, &src->AST()); + dst->SetTransformApplied(src->TransformsApplied()); } ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) { diff --git a/src/program.cc b/src/program.cc index 81c24f0b2e..24d2a1f04b 100644 --- a/src/program.cc +++ b/src/program.cc @@ -33,6 +33,7 @@ Program::Program(Program&& program) sem_(std::move(program.sem_)), symbols_(std::move(program.symbols_)), diagnostics_(std::move(program.diagnostics_)), + transforms_applied_(std::move(program.transforms_applied_)), is_valid_(program.is_valid_) { program.AssertNotMoved(); program.moved_ = true; @@ -57,6 +58,7 @@ Program::Program(ProgramBuilder&& builder) { sem_ = std::move(builder.Sem()); symbols_ = std::move(builder.Symbols()); diagnostics_.add(std::move(builder.Diagnostics())); + transforms_applied_ = builder.TransformsApplied(); builder.MarkAsMoved(); if (!is_valid_ && !diagnostics_.contains_errors()) { @@ -80,6 +82,7 @@ Program& Program::operator=(Program&& program) { sem_ = std::move(program.sem_); symbols_ = std::move(program.symbols_); diagnostics_ = std::move(program.diagnostics_); + transforms_applied_ = std::move(program.transforms_applied_); is_valid_ = program.is_valid_; return *this; } diff --git a/src/program.h b/src/program.h index d2c5f827ea..32534faab5 100644 --- a/src/program.h +++ b/src/program.h @@ -16,6 +16,7 @@ #define SRC_PROGRAM_H_ #include +#include #include "src/ast/function.h" #include "src/program_id.h" @@ -125,6 +126,25 @@ class Program { /// information bool IsValid() const; + /// @return the TypeInfo pointers of all transforms that have been applied to + /// this program. + std::unordered_set TransformsApplied() const { + return transforms_applied_; + } + + /// @param transform the TypeInfo of the transform + /// @returns true if the transform with the given TypeInfo was applied to the + /// Program + bool HasTransformApplied(const TypeInfo* transform) const { + return transforms_applied_.count(transform); + } + + /// @returns true if the transform of type `T` was applied. + template + bool HasTransformApplied() const { + return HasTransformApplied(&TypeInfo::Of()); + } + /// Helper for returning the resolved semantic type of the expression `expr`. /// @param expr the AST expression /// @return the resolved semantic type for the expression, or nullptr if the @@ -180,6 +200,7 @@ class Program { sem::Info sem_; SymbolTable symbols_{id_}; diag::List diagnostics_; + std::unordered_set transforms_applied_; bool is_valid_ = false; // Not valid until it is built bool moved_ = false; }; diff --git a/src/program_builder.cc b/src/program_builder.cc index effe5139ac..fb82ec61d6 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -37,7 +37,9 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs) sem_nodes_(std::move(rhs.sem_nodes_)), ast_(rhs.ast_), sem_(std::move(rhs.sem_)), - symbols_(std::move(rhs.symbols_)) { + symbols_(std::move(rhs.symbols_)), + diagnostics_(std::move(rhs.diagnostics_)), + transforms_applied_(std::move(rhs.transforms_applied_)) { rhs.MarkAsMoved(); } @@ -53,6 +55,8 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) { ast_ = rhs.ast_; sem_ = std::move(rhs.sem_); symbols_ = std::move(rhs.symbols_); + diagnostics_ = std::move(rhs.diagnostics_); + transforms_applied_ = std::move(rhs.transforms_applied_); return *this; } @@ -65,6 +69,7 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) { builder.sem_ = sem::Info::Wrap(program->Sem()); builder.symbols_ = program->Symbols(); builder.diagnostics_ = program->Diagnostics(); + builder.transforms_applied_ = program->TransformsApplied(); return builder; } diff --git a/src/program_builder.h b/src/program_builder.h index bf03e0ae9a..facd53976f 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -16,6 +16,7 @@ #define SRC_PROGRAM_BUILDER_H_ #include +#include #include #include "src/ast/alias.h" @@ -84,13 +85,14 @@ #error "internal tint header being #included from tint.h" #endif -namespace tint { - // Forward declarations +namespace tint { namespace ast { class VariableDeclStatement; } // namespace ast +} // namespace tint +namespace tint { class CloneContext; /// ProgramBuilder is a mutable builder for a Program. @@ -2039,6 +2041,40 @@ class ProgramBuilder { source_ = Source(loc); } + /// Marks that the given transform has been applied to this program. + /// @param transform the transform that has been applied + void SetTransformApplied(const CastableBase* transform) { + transforms_applied_.emplace(&transform->TypeInfo()); + } + + /// Marks that the given transform `T` has been applied to this program. + template + void SetTransformApplied() { + transforms_applied_.emplace(&TypeInfo::Of()); + } + + /// Marks that the transforms with the given TypeInfos have been applied to + /// this program. + /// @param transforms the set of transform TypeInfos that has been applied + void SetTransformApplied( + const std::unordered_set& transforms) { + for (auto* transform : transforms) { + transforms_applied_.emplace(transform); + } + } + + /// @returns true if the transform of type `T` was applied. + template + bool HasTransformApplied() { + return transforms_applied_.count(&TypeInfo::Of()); + } + + /// @return the TypeInfo pointers of all transforms that have been applied to + /// this program. + std::unordered_set TransformsApplied() const { + return transforms_applied_; + } + /// Helper for returning the resolved semantic type of the expression `expr`. /// @note As the Resolver is run when the Program is built, this will only be /// useful for the Resolver itself and tests that use their own Resolver. @@ -2125,6 +2161,7 @@ class ProgramBuilder { sem::Info sem_; SymbolTable symbols_{id_}; diag::List diagnostics_; + std::unordered_set transforms_applied_; /// The source to use when creating AST nodes without providing a Source as /// the first argument. diff --git a/src/transform/array_length_from_uniform.cc b/src/transform/array_length_from_uniform.cc index e7fa0f76b3..9cd1096b1f 100644 --- a/src/transform/array_length_from_uniform.cc +++ b/src/transform/array_length_from_uniform.cc @@ -21,7 +21,10 @@ #include "src/program_builder.h" #include "src/sem/call.h" #include "src/sem/variable.h" +#include "src/transform/inline_pointer_lets.h" +#include "src/transform/simplify.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform); TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result); @@ -31,16 +34,19 @@ namespace transform { ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; -Output ArrayLengthFromUniform::Run(const Program* in, const DataMap& data) { - ProgramBuilder out; - CloneContext ctx(&out, in); +void ArrayLengthFromUniform::Run(CloneContext& ctx, + const DataMap& inputs, + DataMap& outputs) { + if (!Requires(ctx)) { + return; + } - auto* cfg = data.Get(); + auto* cfg = inputs.Get(); if (cfg == nullptr) { - out.Diagnostics().add_error( + ctx.dst->Diagnostics().add_error( diag::System::Transform, "missing transform data for ArrayLengthFromUniform"); - return Output(Program(std::move(out))); + return; } auto& sem = ctx.src->Sem(); @@ -149,8 +155,7 @@ Output ArrayLengthFromUniform::Run(const Program* in, const DataMap& data) { ctx.Clone(); - return Output{Program(std::move(out)), - std::make_unique(buffer_size_ubo ? true : false)}; + outputs.Add(buffer_size_ubo ? true : false); } ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) diff --git a/src/transform/array_length_from_uniform.h b/src/transform/array_length_from_uniform.h index 7ccb588c86..a310616759 100644 --- a/src/transform/array_length_from_uniform.h +++ b/src/transform/array_length_from_uniform.h @@ -49,7 +49,8 @@ namespace transform { /// This transform assumes that the `InlinePointerLets` and `Simplify` /// transforms have been run before it so that arguments to the arrayLength /// builtin always have the form `&resource.array`. -class ArrayLengthFromUniform : public Transform { +class ArrayLengthFromUniform + : public Castable { public: /// Constructor ArrayLengthFromUniform(); @@ -91,11 +92,14 @@ class ArrayLengthFromUniform : public Transform { bool const needs_buffer_sizes; }; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/array_length_from_uniform_test.cc b/src/transform/array_length_from_uniform_test.cc index 006e094fc9..808e4a74e5 100644 --- a/src/transform/array_length_from_uniform_test.cc +++ b/src/transform/array_length_from_uniform_test.cc @@ -16,6 +16,8 @@ #include +#include "src/transform/inline_pointer_lets.h" +#include "src/transform/simplify.h" #include "src/transform/test_helper.h" namespace tint { @@ -29,7 +31,31 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { auto* expect = "error: missing transform data for ArrayLengthFromUniform"; - auto got = Run(src); + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ArrayLengthFromUniformTest, Error_MissingInlinePointerLets) { + auto* src = ""; + + auto* expect = + "error: tint::transform::ArrayLengthFromUniform depends on " + "tint::transform::InlinePointerLets but the dependency was not run"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplify) { + auto* src = ""; + + auto* expect = + "error: tint::transform::ArrayLengthFromUniform depends on " + "tint::transform::Simplify but the dependency was not run"; + + auto got = Run(src); EXPECT_EQ(expect, str(got)); } @@ -78,7 +104,8 @@ fn main() { DataMap data; data.Add(std::move(cfg)); - auto got = Run(src, data); + auto got = + Run(src, data); EXPECT_EQ(expect, str(got)); EXPECT_TRUE( @@ -131,7 +158,8 @@ fn main() { DataMap data; data.Add(std::move(cfg)); - auto got = Run(src, data); + auto got = + Run(src, data); EXPECT_EQ(expect, str(got)); EXPECT_TRUE( @@ -203,7 +231,8 @@ fn main() { DataMap data; data.Add(std::move(cfg)); - auto got = Run(src, data); + auto got = + Run(src, data); EXPECT_EQ(expect, str(got)); EXPECT_TRUE( @@ -232,7 +261,8 @@ fn main() { DataMap data; data.Add(std::move(cfg)); - auto got = Run(src, data); + auto got = + Run(src, data); EXPECT_EQ(src, str(got)); EXPECT_FALSE( @@ -273,7 +303,8 @@ fn main() { DataMap data; data.Add(std::move(cfg)); - auto got = Run(src, data); + auto got = + Run(src, data); EXPECT_EQ(expect, str(got)); } diff --git a/src/transform/binding_remapper.cc b/src/transform/binding_remapper.cc index 5873cfe44a..5e6cc20b90 100644 --- a/src/transform/binding_remapper.cc +++ b/src/transform/binding_remapper.cc @@ -22,6 +22,7 @@ #include "src/sem/function.h" #include "src/sem/variable.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper); TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings); namespace tint { @@ -40,14 +41,13 @@ BindingRemapper::Remappings::~Remappings() = default; BindingRemapper::BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default; -Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { - ProgramBuilder out; - auto* remappings = datamap.Get(); +void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { + auto* remappings = inputs.Get(); if (!remappings) { - out.Diagnostics().add_error( + ctx.dst->Diagnostics().add_error( diag::System::Transform, "BindingRemapper did not find the remapping data"); - return Output(Program(std::move(out))); + return; } // A set of post-remapped binding points that need to be decorated with a @@ -57,11 +57,11 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { if (remappings->allow_collisions) { // Scan for binding point collisions generated by this transform. // Populate all collisions in the `add_collision_deco` set. - for (auto* func_ast : in->AST().Functions()) { + for (auto* func_ast : ctx.src->AST().Functions()) { if (!func_ast->IsEntryPoint()) { continue; } - auto* func = in->Sem().Get(func_ast); + auto* func = ctx.src->Sem().Get(func_ast); std::unordered_map binding_point_counts; for (auto* var : func->ReferencedModuleVariables()) { if (auto binding_point = var->Declaration()->binding_point()) { @@ -85,9 +85,7 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { } } - CloneContext ctx(&out, in); - - for (auto* var : in->AST().GlobalVariables()) { + for (auto* var : ctx.src->AST().GlobalVariables()) { if (auto binding_point = var->binding_point()) { // The original binding point BindingPoint from{binding_point.group->value(), @@ -102,8 +100,8 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { auto bp_it = remappings->binding_points.find(from); if (bp_it != remappings->binding_points.end()) { BindingPoint to = bp_it->second; - auto* new_group = out.create(to.group); - auto* new_binding = out.create(to.binding); + auto* new_group = ctx.dst->create(to.group); + auto* new_binding = ctx.dst->create(to.binding); ctx.Replace(binding_point.group, new_group); ctx.Replace(binding_point.binding, new_binding); @@ -114,7 +112,7 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { auto ac_it = remappings->access_controls.find(from); if (ac_it != remappings->access_controls.end()) { ast::Access ac = ac_it->second; - auto* ty = in->Sem().Get(var)->Type()->UnwrapRef(); + auto* ty = ctx.src->Sem().Get(var)->Type()->UnwrapRef(); ast::Type* inner_ty = CreateASTTypeFor(&ctx, ty); auto* new_var = ctx.dst->create( ctx.Clone(var->source()), ctx.Clone(var->symbol()), @@ -133,8 +131,8 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { } } } + ctx.Clone(); - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/binding_remapper.h b/src/transform/binding_remapper.h index b271c1f497..5fd2fc8d59 100644 --- a/src/transform/binding_remapper.h +++ b/src/transform/binding_remapper.h @@ -29,7 +29,7 @@ using BindingPoint = sem::BindingPoint; /// BindingRemapper is a transform used to remap resource binding points and /// access controls. -class BindingRemapper : public Transform { +class BindingRemapper : public Castable { public: /// BindingPoints is a map of old binding point to new binding point using BindingPoints = std::unordered_map; @@ -68,11 +68,14 @@ class BindingRemapper : public Transform { BindingRemapper(); ~BindingRemapper() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc index e2b1860405..5ad5020c5b 100644 --- a/src/transform/bound_array_accessors.cc +++ b/src/transform/bound_array_accessors.cc @@ -20,20 +20,20 @@ #include "src/program_builder.h" #include "src/sem/expression.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::BoundArrayAccessors); + namespace tint { namespace transform { BoundArrayAccessors::BoundArrayAccessors() = default; BoundArrayAccessors::~BoundArrayAccessors() = default; -Output BoundArrayAccessors::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); +void BoundArrayAccessors::Run(CloneContext& ctx, const DataMap&, DataMap&) { ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) { return Transform(expr, &ctx); }); + ctx.Clone(); - return Output(Program(std::move(out))); } ast::ArrayAccessorExpression* BoundArrayAccessors::Transform( diff --git a/src/transform/bound_array_accessors.h b/src/transform/bound_array_accessors.h index c2330ad476..87803e65a6 100644 --- a/src/transform/bound_array_accessors.h +++ b/src/transform/bound_array_accessors.h @@ -25,18 +25,21 @@ namespace transform { /// the bounds of the array. Any access before the start of the array will clamp /// to zero and any access past the end of the array will clamp to /// (array length - 1). -class BoundArrayAccessors : public Transform { +class BoundArrayAccessors : public Castable { public: /// Constructor BoundArrayAccessors(); /// Destructor ~BoundArrayAccessors() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; private: ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr, diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc index ee48611cb7..098381800c 100644 --- a/src/transform/calculate_array_length.cc +++ b/src/transform/calculate_array_length.cc @@ -27,6 +27,7 @@ #include "src/utils/get_or_create.h" #include "src/utils/hash.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength); TINT_INSTANTIATE_TYPEINFO( tint::transform::CalculateArrayLength::BufferSizeIntrinsic); @@ -69,10 +70,7 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default; -Output CalculateArrayLength::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); // get_buffer_size_intrinsic() emits the function decorated with @@ -232,8 +230,6 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) { } ctx.Clone(); - - return Output{Program(std::move(out))}; } } // namespace transform diff --git a/src/transform/calculate_array_length.h b/src/transform/calculate_array_length.h index ba0279a5ab..20030587fe 100644 --- a/src/transform/calculate_array_length.h +++ b/src/transform/calculate_array_length.h @@ -29,7 +29,7 @@ namespace transform { /// CalculateArrayLength is a transform used to replace calls to arrayLength() /// with a value calculated from the size of the storage buffer. -class CalculateArrayLength : public Transform { +class CalculateArrayLength : public Castable { public: /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic /// functions used to obtain the runtime size of a storage buffer. @@ -56,11 +56,14 @@ class CalculateArrayLength : public Transform { /// Destructor ~CalculateArrayLength() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index eaf9ce9b82..34ef071d4a 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -24,6 +24,7 @@ #include "src/sem/struct.h" #include "src/sem/variable.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO); TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config); namespace tint { @@ -62,16 +63,15 @@ bool StructMemberComparator(const ast::StructMember* a, } // namespace -Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) { - ProgramBuilder out; - CloneContext ctx(&out, in); - - auto* cfg = data.Get(); +void CanonicalizeEntryPointIO::Run(CloneContext& ctx, + const DataMap& inputs, + DataMap&) { + auto* cfg = inputs.Get(); if (cfg == nullptr) { - out.Diagnostics().add_error( + ctx.dst->Diagnostics().add_error( diag::System::Transform, "missing transform data for CanonicalizeEntryPointIO"); - return Output(Program(std::move(out))); + return; } // Strip entry point IO decorations from struct declarations. @@ -375,7 +375,6 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) { } ctx.Clone(); - return Output(Program(std::move(out))); } CanonicalizeEntryPointIO::Config::Config(BuiltinStyle builtins, diff --git a/src/transform/canonicalize_entry_point_io.h b/src/transform/canonicalize_entry_point_io.h index a7f3465d49..5d9a47b30a 100644 --- a/src/transform/canonicalize_entry_point_io.h +++ b/src/transform/canonicalize_entry_point_io.h @@ -68,7 +68,8 @@ namespace transform { /// return retval; /// } /// ``` -class CanonicalizeEntryPointIO : public Transform { +class CanonicalizeEntryPointIO + : public Castable { public: /// BuiltinStyle is an enumerator of different ways to emit builtins. enum class BuiltinStyle { @@ -102,11 +103,14 @@ class CanonicalizeEntryPointIO : public Transform { CanonicalizeEntryPointIO(); ~CanonicalizeEntryPointIO() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/decompose_memory_access.cc b/src/transform/decompose_memory_access.cc index 96a373a151..7f81894711 100644 --- a/src/transform/decompose_memory_access.cc +++ b/src/transform/decompose_memory_access.cc @@ -38,6 +38,7 @@ #include "src/utils/get_or_create.h" #include "src/utils/hash.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess); TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic); namespace tint { @@ -790,10 +791,7 @@ DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone( DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; -Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); State state; @@ -987,7 +985,6 @@ Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) { } ctx.Clone(); - return Output{Program(std::move(out))}; } } // namespace transform diff --git a/src/transform/decompose_memory_access.h b/src/transform/decompose_memory_access.h index c65a5fe8f6..d1cadad6b6 100644 --- a/src/transform/decompose_memory_access.h +++ b/src/transform/decompose_memory_access.h @@ -30,7 +30,8 @@ namespace transform { /// DecomposeMemoryAccess is a transform used to replace storage and uniform /// buffer accesses with a combination of load, store or atomic functions on /// primitive types. -class DecomposeMemoryAccess : public Transform { +class DecomposeMemoryAccess + : public Castable { public: /// Intrinsic is an InternalDecoration that's used to decorate a stub function /// so that the HLSL transforms this into calls to @@ -103,11 +104,14 @@ class DecomposeMemoryAccess : public Transform { /// Destructor ~DecomposeMemoryAccess() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; struct State; }; diff --git a/src/transform/external_texture_transform.cc b/src/transform/external_texture_transform.cc index 4d48ff7d4c..1ede9b8e47 100644 --- a/src/transform/external_texture_transform.cc +++ b/src/transform/external_texture_transform.cc @@ -18,15 +18,17 @@ #include "src/sem/call.h" #include "src/sem/variable.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::ExternalTextureTransform); + namespace tint { namespace transform { ExternalTextureTransform::ExternalTextureTransform() = default; ExternalTextureTransform::~ExternalTextureTransform() = default; -Output ExternalTextureTransform::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); +void ExternalTextureTransform::Run(CloneContext& ctx, + const DataMap&, + DataMap&) { auto& sem = ctx.src->Sem(); // Within this transform, usages of texture_external are replaced with a @@ -105,7 +107,7 @@ Output ExternalTextureTransform::Run(const Program* in, const DataMap&) { // Scan the AST nodes for external texture declarations. for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* var = node->As()) { - if (Is(var->type())) { + if (::tint::Is(var->type())) { // Replace a single-plane external texture with a 2D, f32 sampled // texture. auto* newType = ctx.dst->ty.sampled_texture(ast::TextureDimension::k2d, @@ -125,7 +127,6 @@ Output ExternalTextureTransform::Run(const Program* in, const DataMap&) { } ctx.Clone(); - return Output{Program(std::move(out))}; } } // namespace transform diff --git a/src/transform/external_texture_transform.h b/src/transform/external_texture_transform.h index 350122bf2b..6bb8091f5f 100644 --- a/src/transform/external_texture_transform.h +++ b/src/transform/external_texture_transform.h @@ -27,18 +27,22 @@ namespace transform { /// This allows us to share SPIR-V/HLSL writer paths for sampled textures /// instead of adding dedicated writer paths for external textures. /// ExternalTextureTransform performs this transformation. -class ExternalTextureTransform : public Transform { +class ExternalTextureTransform + : public Castable { public: /// Constructor ExternalTextureTransform(); /// Destructor ~ExternalTextureTransform() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 4a69c63931..8a28364ae6 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -25,6 +25,7 @@ #include "src/sem/struct.h" #include "src/sem/variable.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset); TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::BindingPoint); TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::Data); @@ -57,18 +58,17 @@ FirstIndexOffset::Data::~Data() = default; FirstIndexOffset::FirstIndexOffset() = default; FirstIndexOffset::~FirstIndexOffset() = default; -Output FirstIndexOffset::Run(const Program* in, const DataMap& data) { +void FirstIndexOffset::Run(CloneContext& ctx, + const DataMap& inputs, + DataMap& outputs) { // Get the uniform buffer binding point uint32_t ub_binding = binding_; uint32_t ub_group = group_; - if (auto* binding_point = data.Get()) { + if (auto* binding_point = inputs.Get()) { ub_binding = binding_point->binding; ub_group = binding_point->group; } - ProgramBuilder out; - CloneContext ctx(&out, in); - // Map of builtin usages std::unordered_map builtin_vars; std::unordered_map builtin_members; @@ -78,7 +78,7 @@ Output FirstIndexOffset::Run(const Program* in, const DataMap& data) { // Traverse the AST scanning for builtin accesses via variables (includes // parameters) or structure member accesses. - for (auto* node : in->ASTNodes().Objects()) { + for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* var = node->As()) { for (ast::Decoration* dec : var->decorations()) { if (auto* builtin_dec = dec->As()) { @@ -173,10 +173,8 @@ Output FirstIndexOffset::Run(const Program* in, const DataMap& data) { ctx.Clone(); - return Output( - Program(std::move(out)), - std::make_unique(has_vertex_index, has_instance_index, - vertex_index_offset, instance_index_offset)); + outputs.Add(has_vertex_index, has_instance_index, vertex_index_offset, + instance_index_offset); } } // namespace transform diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h index f84e9e3939..2ba27b5c24 100644 --- a/src/transform/first_index_offset.h +++ b/src/transform/first_index_offset.h @@ -55,7 +55,7 @@ namespace transform { /// return vert_idx; /// } /// -class FirstIndexOffset : public Transform { +class FirstIndexOffset : public Castable { public: /// BindingPoint is consumed by the FirstIndexOffset transform. /// BindingPoint specifies the binding point of the first index uniform @@ -112,11 +112,14 @@ class FirstIndexOffset : public Transform { /// Destructor ~FirstIndexOffset() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; private: uint32_t binding_ = 0; diff --git a/src/transform/fold_constants.cc b/src/transform/fold_constants.cc index 3fe3691d64..6fcc7986fb 100644 --- a/src/transform/fold_constants.cc +++ b/src/transform/fold_constants.cc @@ -20,6 +20,8 @@ #include "src/program_builder.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants); + namespace tint { namespace { @@ -318,10 +320,7 @@ FoldConstants::FoldConstants() = default; FoldConstants::~FoldConstants() = default; -Output FoldConstants::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) { ExprToValue expr_to_value; // Visit inner expressions before outer expressions @@ -345,8 +344,6 @@ Output FoldConstants::Run(const Program* in, const DataMap&) { } ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/fold_constants.h b/src/transform/fold_constants.h index 7e18337529..861a06b09b 100644 --- a/src/transform/fold_constants.h +++ b/src/transform/fold_constants.h @@ -21,7 +21,7 @@ namespace tint { namespace transform { /// FoldConstants transforms the AST by folding constant expressions -class FoldConstants : public Transform { +class FoldConstants : public Castable { public: /// Constructor FoldConstants(); @@ -29,11 +29,14 @@ class FoldConstants : public Transform { /// Destructor ~FoldConstants() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index 1009f053e4..cca668f5b6 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -29,6 +29,8 @@ #include "src/transform/wrap_arrays_in_structs.h" #include "src/transform/zero_init_workgroup_memory.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Hlsl); + namespace tint { namespace transform { @@ -69,6 +71,7 @@ Output Hlsl::Run(const Program* in, const DataMap&) { CloneContext ctx(&builder, &out.program); AddEmptyEntryPoint(ctx); ctx.Clone(); + builder.SetTransformApplied(this); return Output{Program(std::move(builder))}; } diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h index 9cffd7760f..091903e3ea 100644 --- a/src/transform/hlsl.h +++ b/src/transform/hlsl.h @@ -27,7 +27,7 @@ namespace transform { /// Hlsl is a transform used to sanitize a Program for use with the Hlsl writer. /// Passing a non-sanitized Program to the Hlsl writer will result in undefined /// behavior. -class Hlsl : public Transform { +class Hlsl : public Castable { public: /// Constructor Hlsl(); diff --git a/src/transform/inline_pointer_lets.cc b/src/transform/inline_pointer_lets.cc index b82d7307e1..6b18f75b47 100644 --- a/src/transform/inline_pointer_lets.cc +++ b/src/transform/inline_pointer_lets.cc @@ -25,6 +25,8 @@ #include "src/sem/variable.h" #include "src/utils/scoped_assignment.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::InlinePointerLets); + namespace tint { namespace transform { namespace { @@ -79,10 +81,7 @@ InlinePointerLets::InlinePointerLets() = default; InlinePointerLets::~InlinePointerLets() = default; -Output InlinePointerLets::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) { // If not null, current_ptr_let is the current PtrLet being operated on. PtrLet* current_ptr_let = nullptr; // A map of the AST `let` variable to the PtrLet @@ -107,7 +106,7 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) { } } if (auto* ident = expr->As()) { - if (auto* vu = in->Sem().Get(ident)) { + if (auto* vu = ctx.src->Sem().Get(ident)) { auto* var = vu->Variable()->Declaration(); auto it = ptr_lets.find(var); if (it != ptr_lets.end()) { @@ -130,13 +129,13 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) { // Find all the pointer-typed `let` declarations. // Note that these must be function-scoped, as module-scoped `let`s are not // permitted. - for (auto* node : in->ASTNodes().Objects()) { + for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* let = node->As()) { if (!let->variable()->is_const()) { continue; // Not a `let` declaration. Ignore. } - auto* var = in->Sem().Get(let->variable()); + auto* var = ctx.src->Sem().Get(let->variable()); if (!var->Type()->Is()) { continue; // Not a pointer type. Ignore. } @@ -183,8 +182,6 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) { } ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/inline_pointer_lets.h b/src/transform/inline_pointer_lets.h index 75ca347d10..accc056b54 100644 --- a/src/transform/inline_pointer_lets.h +++ b/src/transform/inline_pointer_lets.h @@ -31,7 +31,7 @@ namespace transform { /// Note: InlinePointerLets does not operate on module-scope `let`s, as these /// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants /// `A module-scope let-declared constant must be of atomic-free plain type.` -class InlinePointerLets : public Transform { +class InlinePointerLets : public Castable { public: /// Constructor InlinePointerLets(); @@ -39,11 +39,14 @@ class InlinePointerLets : public Transform { /// Destructor ~InlinePointerLets() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/manager.cc b/src/transform/manager.cc index 10e3f58420..0c1a0f030c 100644 --- a/src/transform/manager.cc +++ b/src/transform/manager.cc @@ -14,6 +14,8 @@ #include "src/transform/manager.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager); + namespace tint { namespace transform { diff --git a/src/transform/manager.h b/src/transform/manager.h index 345afb7d17..8d31bb57fc 100644 --- a/src/transform/manager.h +++ b/src/transform/manager.h @@ -28,7 +28,7 @@ namespace transform { /// The inner transforms will execute in the appended order. /// If any inner transform fails the manager will return immediately and /// the error can be retrieved with the Output's diagnostics. -class Manager : public Transform { +class Manager : public Castable { public: /// Constructor Manager(); diff --git a/src/transform/msl.cc b/src/transform/msl.cc index c89a29da12..b1c4605371 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -36,6 +36,7 @@ #include "src/transform/wrap_arrays_in_structs.h" #include "src/transform/zero_init_workgroup_memory.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl); TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result); @@ -103,6 +104,8 @@ Output Msl::Run(const Program* in, const DataMap& inputs) { auto result = std::make_unique( out.data.Get()->needs_buffer_sizes); + + builder.SetTransformApplied(this); return Output{Program(std::move(builder)), std::move(result)}; } diff --git a/src/transform/msl.h b/src/transform/msl.h index 001dfe02c4..0fb6cb58f0 100644 --- a/src/transform/msl.h +++ b/src/transform/msl.h @@ -23,7 +23,7 @@ namespace transform { /// Msl is a transform used to sanitize a Program for use with the Msl writer. /// Passing a non-sanitized Program to the Msl writer will result in undefined /// behavior. -class Msl : public Transform { +class Msl : public Castable { public: /// The default buffer slot to use for the storage buffer size buffer. const uint32_t kDefaultBufferSizeUniformIndex = 30; diff --git a/src/transform/pad_array_elements.cc b/src/transform/pad_array_elements.cc index 24268f0bc5..a6e335123e 100644 --- a/src/transform/pad_array_elements.cc +++ b/src/transform/pad_array_elements.cc @@ -22,6 +22,8 @@ #include "src/sem/expression.h" #include "src/utils/get_or_create.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements); + namespace tint { namespace transform { namespace { @@ -89,10 +91,7 @@ PadArrayElements::PadArrayElements() = default; PadArrayElements::~PadArrayElements() = default; -Output PadArrayElements::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); std::unordered_map padded_arrays; @@ -149,8 +148,6 @@ Output PadArrayElements::Run(const Program* in, const DataMap&) { }); ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/pad_array_elements.h b/src/transform/pad_array_elements.h index db9f21553b..cbb3e147ca 100644 --- a/src/transform/pad_array_elements.h +++ b/src/transform/pad_array_elements.h @@ -30,7 +30,7 @@ namespace transform { /// structure element type. /// This transform helps with backends that cannot directly return arrays or use /// them as parameters. -class PadArrayElements : public Transform { +class PadArrayElements : public Castable { public: /// Constructor PadArrayElements(); @@ -38,11 +38,14 @@ class PadArrayElements : public Transform { /// Destructor ~PadArrayElements() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/promote_initializers_to_const_var.cc b/src/transform/promote_initializers_to_const_var.cc index d0724d370e..87e7976075 100644 --- a/src/transform/promote_initializers_to_const_var.cc +++ b/src/transform/promote_initializers_to_const_var.cc @@ -21,6 +21,8 @@ #include "src/sem/expression.h" #include "src/sem/statement.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar); + namespace tint { namespace transform { @@ -28,10 +30,9 @@ PromoteInitializersToConstVar::PromoteInitializersToConstVar() = default; PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default; -Output PromoteInitializersToConstVar::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void PromoteInitializersToConstVar::Run(CloneContext& ctx, + const DataMap&, + DataMap&) { // Scan the AST nodes for array and structure initializers which // need to be promoted to their own constant declaration. @@ -100,8 +101,6 @@ Output PromoteInitializersToConstVar::Run(const Program* in, const DataMap&) { } ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/promote_initializers_to_const_var.h b/src/transform/promote_initializers_to_const_var.h index 6162363c8c..59adfb736b 100644 --- a/src/transform/promote_initializers_to_const_var.h +++ b/src/transform/promote_initializers_to_const_var.h @@ -23,7 +23,8 @@ namespace transform { /// A transform that hoists the array and structure initializers to a constant /// variable, declared just before the statement of usage. See /// crbug.com/tint/406 for more details. -class PromoteInitializersToConstVar : public Transform { +class PromoteInitializersToConstVar + : public Castable { public: /// Constructor PromoteInitializersToConstVar(); @@ -31,11 +32,14 @@ class PromoteInitializersToConstVar : public Transform { /// Destructor ~PromoteInitializersToConstVar() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/renamer.cc b/src/transform/renamer.cc index c39171fc4b..b6f6b787dc 100644 --- a/src/transform/renamer.cc +++ b/src/transform/renamer.cc @@ -22,7 +22,9 @@ #include "src/sem/call.h" #include "src/sem/member_accessor_expression.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer); TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer::Data); +TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer::Config); namespace tint { namespace transform { @@ -835,18 +837,18 @@ const char* kReservedKeywordsMSL[] = {"access", } // namespace Renamer::Data::Data(Remappings&& r) : remappings(std::move(r)) {} - Renamer::Data::Data(const Data&) = default; - Renamer::Data::~Data() = default; -Renamer::Renamer() : cfg_{} {} - -Renamer::Renamer(const Config& config) : cfg_(config) {} +Renamer::Config::Config(Target t) : target(t) {} +Renamer::Config::Config(const Config&) = default; +Renamer::Config::~Config() = default; +Renamer::Renamer() : deprecated_cfg_(Target::kAll) {} +Renamer::Renamer(const Config& config) : deprecated_cfg_(config) {} Renamer::~Renamer() = default; -Output Renamer::Run(const Program* in, const DataMap&) { +Output Renamer::Run(const Program* in, const DataMap& inputs) { ProgramBuilder out; // Disable auto-cloning of symbols, since we want to rename them. CloneContext ctx(&out, in, false); @@ -879,9 +881,14 @@ Output Renamer::Run(const Program* in, const DataMap&) { Data::Remappings remappings; + auto* cfg = inputs.Get(); + if (!cfg) { + cfg = &deprecated_cfg_; + } + ctx.ReplaceAll([&](Symbol sym_in) { auto name_in = ctx.src->Symbols().NameFor(sym_in); - switch (cfg_.target) { + switch (cfg->target) { case Target::kAll: // Always rename. break; diff --git a/src/transform/renamer.h b/src/transform/renamer.h index 55cf5dcd2e..33e9142d80 100644 --- a/src/transform/renamer.h +++ b/src/transform/renamer.h @@ -24,7 +24,7 @@ namespace tint { namespace transform { /// Renamer is a Transform that renames all the symbols in a program. -class Renamer : public Transform { +class Renamer : public Castable { public: /// Data is outputted by the Renamer transform. /// Data holds information about shader usage and constant buffer offsets. @@ -57,16 +57,27 @@ class Renamer : public Transform { }; /// Configuration options for the transform - struct Config { + struct Config : public Castable { + /// Constructor + /// @param tgt the targets to rename + explicit Config(Target tgt); + + /// Copy constructor + Config(const Config&); + + /// Destructor + ~Config() override; + /// The targets to rename - Target target = Target::kAll; + Target const target = Target::kAll; }; - /// Constructor using a default configuration + /// Constructor using a the configuration provided in the input Data Renamer(); /// Constructor /// @param config the configuration for the transform + /// [DEPRECATED] Pass Config as input Data explicit Renamer(const Config& config); /// Destructor @@ -79,7 +90,7 @@ class Renamer : public Transform { Output Run(const Program* program, const DataMap& data = {}) override; private: - Config const cfg_; + Config const deprecated_cfg_; }; } // namespace transform diff --git a/src/transform/renamer_test.cc b/src/transform/renamer_test.cc index ca2bc446a2..73467a08fb 100644 --- a/src/transform/renamer_test.cc +++ b/src/transform/renamer_test.cc @@ -207,8 +207,9 @@ fn frag_main() { } )"; - Renamer::Config config{Renamer::Target::kHlslKeywords}; - auto got = Run(src, std::make_unique(config)); + DataMap inputs; + inputs.Add(Renamer::Target::kHlslKeywords); + auto got = Run(src, inputs); EXPECT_EQ(expect, str(got)); } @@ -231,8 +232,9 @@ fn frag_main() { } )"; - Renamer::Config config{Renamer::Target::kMslKeywords}; - auto got = Run(src, std::make_unique(config)); + DataMap inputs; + inputs.Add(Renamer::Target::kMslKeywords); + auto got = Run(src, inputs); EXPECT_EQ(expect, str(got)); } diff --git a/src/transform/simplify.cc b/src/transform/simplify.cc index c9ac2ef90f..c007d35fc4 100644 --- a/src/transform/simplify.cc +++ b/src/transform/simplify.cc @@ -25,6 +25,8 @@ #include "src/sem/variable.h" #include "src/utils/scoped_assignment.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Simplify); + namespace tint { namespace transform { @@ -32,10 +34,7 @@ Simplify::Simplify() = default; Simplify::~Simplify() = default; -Output Simplify::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void Simplify::Run(CloneContext& ctx, const DataMap&, DataMap&) { ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* { if (auto* outer = expr->As()) { if (auto* inner = outer->expr()->As()) { @@ -55,8 +54,6 @@ Output Simplify::Run(const Program* in, const DataMap&) { }); ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/simplify.h b/src/transform/simplify.h index f6fe5b5ec6..084e0baa7c 100644 --- a/src/transform/simplify.h +++ b/src/transform/simplify.h @@ -28,7 +28,7 @@ namespace transform { /// Simplify currently optimizes the following: /// `&(*(expr))` => `expr` /// `*(&(expr))` => `expr` -class Simplify : public Transform { +class Simplify : public Castable { public: /// Constructor Simplify(); @@ -36,11 +36,14 @@ class Simplify : public Transform { /// Destructor ~Simplify() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/single_entry_point.cc b/src/transform/single_entry_point.cc index 15afa5f3d7..c8ece461a5 100644 --- a/src/transform/single_entry_point.cc +++ b/src/transform/single_entry_point.cc @@ -21,6 +21,7 @@ #include "src/sem/function.h" #include "src/sem/variable.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint); TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config); namespace tint { @@ -30,69 +31,63 @@ SingleEntryPoint::SingleEntryPoint() = default; SingleEntryPoint::~SingleEntryPoint() = default; -Output SingleEntryPoint::Run(const Program* in, const DataMap& data) { - ProgramBuilder out; - - auto* cfg = data.Get(); +void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { + auto* cfg = inputs.Get(); if (cfg == nullptr) { - out.Diagnostics().add_error(diag::System::Transform, - "missing transform data for SingleEntryPoint"); - return Output(Program(std::move(out))); + ctx.dst->Diagnostics().add_error( + diag::System::Transform, "missing transform data for SingleEntryPoint"); + return; } // Find the target entry point. ast::Function* entry_point = nullptr; - for (auto* f : in->AST().Functions()) { + for (auto* f : ctx.src->AST().Functions()) { if (!f->IsEntryPoint()) { continue; } - if (in->Symbols().NameFor(f->symbol()) == cfg->entry_point_name) { + if (ctx.src->Symbols().NameFor(f->symbol()) == cfg->entry_point_name) { entry_point = f; break; } } if (entry_point == nullptr) { - out.Diagnostics().add_error( + ctx.dst->Diagnostics().add_error( diag::System::Transform, "entry point '" + cfg->entry_point_name + "' not found"); - return Output(Program(std::move(out))); + return; } - CloneContext ctx(&out, in); - - auto* sem = in->Sem().Get(entry_point); + auto& sem = ctx.src->Sem(); // Build set of referenced module-scope variables for faster lookups later. std::unordered_set referenced_vars; - for (auto* var : sem->ReferencedModuleVariables()) { + for (auto* var : sem.Get(entry_point)->ReferencedModuleVariables()) { referenced_vars.emplace(var->Declaration()); } // Clone any module-scope variables, types, and functions that are statically // referenced by the target entry point. - for (auto* decl : in->AST().GlobalDeclarations()) { + for (auto* decl : ctx.src->AST().GlobalDeclarations()) { if (auto* ty = decl->As()) { // TODO(jrprice): Strip unused types. - out.AST().AddTypeDecl(ctx.Clone(ty)); + ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); } else if (auto* var = decl->As()) { if (var->is_const() || referenced_vars.count(var)) { - out.AST().AddGlobalVariable(ctx.Clone(var)); + ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); } } else if (auto* func = decl->As()) { - if (in->Sem().Get(func)->HasAncestorEntryPoint(entry_point->symbol())) { - out.AST().AddFunction(ctx.Clone(func)); + if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol())) { + ctx.dst->AST().AddFunction(ctx.Clone(func)); } } else { - TINT_UNREACHABLE(Transform, out.Diagnostics()) + TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << "unhandled global declaration: " << decl->TypeInfo().name; - return Output(Program(std::move(out))); + return; } } // Clone the entry point. - out.AST().AddFunction(ctx.Clone(entry_point)); - - return Output(Program(std::move(out))); + ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); } SingleEntryPoint::Config::Config(std::string entry_point) diff --git a/src/transform/single_entry_point.h b/src/transform/single_entry_point.h index c2cd16c1b6..447b5c4da7 100644 --- a/src/transform/single_entry_point.h +++ b/src/transform/single_entry_point.h @@ -26,7 +26,7 @@ namespace transform { /// /// All module-scope variables, types, and functions that are not used by the /// target entry point will also be removed. -class SingleEntryPoint : public Transform { +class SingleEntryPoint : public Castable { public: /// Configuration options for the transform struct Config : public Castable { @@ -54,11 +54,14 @@ class SingleEntryPoint : public Transform { /// Destructor ~SingleEntryPoint() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; }; } // namespace transform diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 874c1142bf..b88a4271e2 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -34,6 +34,7 @@ #include "src/transform/simplify.h" #include "src/transform/zero_init_workgroup_memory.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Spirv); TINT_INSTANTIATE_TYPEINFO(tint::transform::Spirv::Config); namespace tint { @@ -70,6 +71,7 @@ Output Spirv::Run(const Program* in, const DataMap& data) { } ctx2.Clone(); + out2.SetTransformApplied(this); return Output{Program(std::move(out2))}; } diff --git a/src/transform/spirv.h b/src/transform/spirv.h index febf5bb2a1..dba32b5ec9 100644 --- a/src/transform/spirv.h +++ b/src/transform/spirv.h @@ -29,7 +29,7 @@ namespace transform { /// Spirv is a transform used to sanitize a Program for use with the Spirv /// writer. Passing a non-sanitized Program to the Spirv writer will result in /// undefined behavior. -class Spirv : public Transform { +class Spirv : public Castable { public: /// Configuration options for the transform. struct Config : public Castable { diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h index dcd11b0517..ed21a60415 100644 --- a/src/transform/test_helper.h +++ b/src/transform/test_helper.h @@ -32,32 +32,6 @@ namespace transform { template class TransformTestBase : public BASE { public: - /// Transforms and returns the WGSL source `in`, transformed using - /// `transforms`. - /// @param in the input WGSL source - /// @param transforms the list of transforms to apply - /// @param data the optional DataMap to pass to Transform::Run() - /// @return the transformed output - Output Run(std::string in, - std::vector> transforms, - const DataMap& data = {}) { - auto file = std::make_unique("test", in); - auto program = reader::wgsl::Parse(file.get()); - - // Keep this pointer alive after Transform() returns - files_.emplace_back(std::move(file)); - - if (!program.IsValid()) { - return Output(std::move(program)); - } - - Manager manager; - for (auto& transform : transforms) { - manager.append(std::move(transform)); - } - return manager.Run(&program, data); - } - /// Transforms and returns the WGSL source `in`, transformed using /// `transform`. /// @param transform the transform to apply @@ -77,9 +51,24 @@ class TransformTestBase : public BASE { /// @param in the input WGSL source /// @param data the optional DataMap to pass to Transform::Run() /// @return the transformed output - template + template Output Run(std::string in, const DataMap& data = {}) { - return Run(std::move(in), std::make_unique(), data); + auto file = std::make_unique("test", in); + auto program = reader::wgsl::Parse(file.get()); + + // Keep this pointer alive after Transform() returns + files_.emplace_back(std::move(file)); + + if (!program.IsValid()) { + return Output(std::move(program)); + } + + Manager manager; + for (auto* transform_ptr : + std::initializer_list{new TRANSFORMS()...}) { + manager.append(std::unique_ptr(transform_ptr)); + } + return manager.Run(&program, data); } /// @param output the output of the transform diff --git a/src/transform/transform.cc b/src/transform/transform.cc index 914b167b25..a6b46bd9a5 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -15,11 +15,13 @@ #include "src/transform/transform.h" #include +#include #include "src/program_builder.h" #include "src/sem/atomic_type.h" #include "src/sem/reference_type.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform); TINT_INSTANTIATE_TYPEINFO(tint::transform::Data); namespace tint { @@ -40,6 +42,35 @@ Output::Output(Program&& p) : program(std::move(p)) {} Transform::Transform() = default; Transform::~Transform() = default; +Output Transform::Run(const Program* program, const DataMap& data /* = {} */) { + ProgramBuilder builder; + CloneContext ctx(&builder, program); + Output output; + Run(ctx, data, output.data); + builder.SetTransformApplied(this); + output.program = Program(std::move(builder)); + return output; +} + +void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) { + TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics()) + << "Transform::Run() unimplemented for " << TypeInfo().name; +} + +bool Transform::Requires(CloneContext& ctx, + std::initializer_list deps) { + for (auto* dep : deps) { + if (!ctx.src->HasTransformApplied(dep)) { + ctx.dst->Diagnostics().add_error( + diag::System::Transform, std::string(TypeInfo().name) + + " depends on " + std::string(dep->name) + + " but the dependency was not run"); + return false; + } + } + return true; +} + ast::Function* Transform::CloneWithStatementsAtStart( CloneContext* ctx, ast::Function* in, diff --git a/src/transform/transform.h b/src/transform/transform.h index 3506af106a..b589ab5f7f 100644 --- a/src/transform/transform.h +++ b/src/transform/transform.h @@ -19,6 +19,7 @@ #include #include +#include "src/castable.h" #include "src/program.h" namespace tint { @@ -145,20 +146,45 @@ class Output { }; /// Interface for Program transforms -class Transform { +class Transform : public Castable { public: /// Constructor Transform(); /// Destructor - virtual ~Transform(); + ~Transform() override; /// Runs the transform on `program`, returning the transformation result. /// @param program the source program to transform /// @param data optional extra transform-specific input data /// @returns the transformation result - virtual Output Run(const Program* program, const DataMap& data = {}) = 0; + virtual Output Run(const Program* program, const DataMap& data = {}); protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs); + + /// Requires appends an error diagnostic to `ctx.dst` if the template type + /// transforms were not already run on `ctx.src`. + /// @param ctx the CloneContext + /// @returns true if all dependency transforms have been run + template + bool Requires(CloneContext& ctx) { + return Requires(ctx, {&::tint::TypeInfo::Of()...}); + } + + /// Requires appends an error diagnostic to `ctx.dst` if the list of + /// Transforms were not already run on `ctx.src`. + /// @param ctx the CloneContext + /// @param deps the list of Transform TypeInfos + /// @returns true if all dependency transforms have been run + bool Requires(CloneContext& ctx, + std::initializer_list deps); + /// Clones the function `in` adding `statements` to the beginning of the /// cloned function body. /// @param ctx the clone context diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index e06dd7c348..1161de9bee 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -24,6 +24,7 @@ #include "src/sem/variable.h" #include "src/utils/get_or_create.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling); TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config); namespace tint { @@ -456,21 +457,20 @@ struct State { VertexPulling::VertexPulling() = default; VertexPulling::~VertexPulling() = default; -Output VertexPulling::Run(const Program* in, const DataMap& data) { - ProgramBuilder out; - +void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { auto cfg = cfg_; - if (auto* cfg_data = data.Get()) { + if (auto* cfg_data = inputs.Get()) { cfg = *cfg_data; } // Find entry point - auto* func = in->AST().Functions().Find( - in->Symbols().Get(cfg.entry_point_name), ast::PipelineStage::kVertex); + auto* func = ctx.src->AST().Functions().Find( + ctx.src->Symbols().Get(cfg.entry_point_name), + ast::PipelineStage::kVertex); if (func == nullptr) { - out.Diagnostics().add_error(diag::System::Transform, - "Vertex stage entry point not found"); - return Output(Program(std::move(out))); + ctx.dst->Diagnostics().add_error(diag::System::Transform, + "Vertex stage entry point not found"); + return; } // TODO(idanr): Need to check shader locations in descriptor cover all @@ -479,15 +479,11 @@ Output VertexPulling::Run(const Program* in, const DataMap& data) { // TODO(idanr): Make sure we covered all error cases, to guarantee the // following stages will pass - CloneContext ctx(&out, in); - State state{ctx, cfg}; state.AddVertexStorageBuffers(); state.Process(func); ctx.Clone(); - - return Output(Program(std::move(out))); } VertexPulling::Config::Config() = default; diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h index 15ea211f00..6affca4865 100644 --- a/src/transform/vertex_pulling.h +++ b/src/transform/vertex_pulling.h @@ -130,7 +130,7 @@ using VertexStateDescriptor = std::vector; /// code, but these are types that the data may arrive as. We need to convert /// these smaller types into the base types such as `f32` and `u32` for the /// shader to use. -class VertexPulling : public Transform { +class VertexPulling : public Castable { public: /// Configuration options for the transform struct Config : public Castable { @@ -164,11 +164,14 @@ class VertexPulling : public Transform { /// Destructor ~VertexPulling() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; private: Config cfg_; diff --git a/src/transform/wrap_arrays_in_structs.cc b/src/transform/wrap_arrays_in_structs.cc index b1715aa15f..d7c7396f93 100644 --- a/src/transform/wrap_arrays_in_structs.cc +++ b/src/transform/wrap_arrays_in_structs.cc @@ -21,6 +21,8 @@ #include "src/sem/expression.h" #include "src/utils/get_or_create.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs); + namespace tint { namespace transform { @@ -33,10 +35,7 @@ WrapArraysInStructs::WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default; -Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); std::unordered_map wrapped_arrays; @@ -60,8 +59,8 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { // Fix up array accessors so `a[1]` becomes `a.arr[1]` ctx.ReplaceAll([&](ast::ArrayAccessorExpression* accessor) -> ast::ArrayAccessorExpression* { - if (auto* array = - As(sem.Get(accessor->array())->Type()->UnwrapRef())) { + if (auto* array = ::tint::As( + sem.Get(accessor->array())->Type()->UnwrapRef())) { if (wrapper(array)) { // Array is wrapped in a structure. Emit a member accessor to get // to the actual array. @@ -76,7 +75,8 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))` ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* { - if (auto* array = As(sem.Get(ctor)->Type()->UnwrapRef())) { + if (auto* array = + ::tint::As(sem.Get(ctor)->Type()->UnwrapRef())) { if (auto w = wrapper(array)) { // Wrap the array type constructor with another constructor for // the wrapper @@ -91,8 +91,6 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { }); ctx.Clone(); - - return Output(Program(std::move(out))); } WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray( diff --git a/src/transform/wrap_arrays_in_structs.h b/src/transform/wrap_arrays_in_structs.h index 2dda1eeef2..6eb4bc164b 100644 --- a/src/transform/wrap_arrays_in_structs.h +++ b/src/transform/wrap_arrays_in_structs.h @@ -36,7 +36,7 @@ namespace transform { /// wrapping. /// This transform helps with backends that cannot directly return arrays or use /// them as parameters. -class WrapArraysInStructs : public Transform { +class WrapArraysInStructs : public Castable { public: /// Constructor WrapArraysInStructs(); @@ -44,11 +44,14 @@ class WrapArraysInStructs : public Transform { /// Destructor ~WrapArraysInStructs() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; private: struct WrappedArrayInfo { diff --git a/src/transform/zero_init_workgroup_memory.cc b/src/transform/zero_init_workgroup_memory.cc index 287018c1b1..80ba53a780 100644 --- a/src/transform/zero_init_workgroup_memory.cc +++ b/src/transform/zero_init_workgroup_memory.cc @@ -23,6 +23,8 @@ #include "src/sem/variable.h" #include "src/utils/get_or_create.h" +TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory); + namespace tint { namespace transform { @@ -111,13 +113,10 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; -Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) { - ProgramBuilder out; - CloneContext ctx(&out, in); - +void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); - for (auto* ast_func : in->AST().Functions()) { + for (auto* ast_func : ctx.src->AST().Functions()) { if (!ast_func->IsEntryPoint()) { continue; } @@ -192,8 +191,6 @@ Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) { } ctx.Clone(); - - return Output(Program(std::move(out))); } } // namespace transform diff --git a/src/transform/zero_init_workgroup_memory.h b/src/transform/zero_init_workgroup_memory.h index bf846c746e..1644b4d08e 100644 --- a/src/transform/zero_init_workgroup_memory.h +++ b/src/transform/zero_init_workgroup_memory.h @@ -23,7 +23,8 @@ namespace transform { /// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry /// points to zero-initialize workgroup memory used by that entry point (and all /// transitive functions called by that entry point) -class ZeroInitWorkgroupMemory : public Transform { +class ZeroInitWorkgroupMemory + : public Castable { public: /// Constructor ZeroInitWorkgroupMemory(); @@ -31,11 +32,14 @@ class ZeroInitWorkgroupMemory : public Transform { /// Destructor ~ZeroInitWorkgroupMemory() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) override; + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; private: struct State; diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index ce1aea88e3..f7677e4a55 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -36,6 +36,7 @@ #include "src/sem/struct.h" #include "src/sem/variable.h" #include "src/transform/calculate_array_length.h" +#include "src/transform/hlsl.h" #include "src/utils/scoped_assignment.h" #include "src/writer/append_vector.h" #include "src/writer/float_to_string.h" @@ -114,6 +115,14 @@ GeneratorImpl::GeneratorImpl(const Program* program) GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate(std::ostream& out) { + if (!builder_.HasTransformApplied()) { + diagnostics_.add_error( + diag::System::Writer, + "HLSL writer requires the transform::Hlsl sanitizer to have been " + "applied to the input program"); + return false; + } + std::stringstream pending; const TypeInfo* last_kind = nullptr; diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc index 341b2d155c..adf73e43bf 100644 --- a/src/writer/hlsl/generator_impl_test.cc +++ b/src/writer/hlsl/generator_impl_test.cc @@ -21,6 +21,16 @@ namespace { using HlslGeneratorImplTest = TestHelper; +TEST_F(HlslGeneratorImplTest, ErrorIfSanitizerNotRun) { + auto program = std::make_unique(std::move(*this)); + GeneratorImpl gen(program.get()); + EXPECT_FALSE(gen.Generate(out)); + EXPECT_EQ( + gen.error(), + "error: HLSL writer requires the transform::Hlsl sanitizer to have been " + "applied to the input program"); +} + TEST_F(HlslGeneratorImplTest, Generate) { Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, ast::DecorationList{}); diff --git a/src/writer/hlsl/test_helper.h b/src/writer/hlsl/test_helper.h index 6ccbb3ba39..26c55630f0 100644 --- a/src/writer/hlsl/test_helper.h +++ b/src/writer/hlsl/test_helper.h @@ -44,15 +44,17 @@ class TestHelperBase : public BODY, public ProgramBuilder { if (gen_) { return *gen_; } - diag::Formatter formatter; + // Fake that the HLSL sanitizer has been applied, so that we can unit test + // the writer without it erroring. + SetTransformApplied(); [&]() { ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" - << formatter.format(Diagnostics()); + << diag::Formatter().format(Diagnostics()); }(); program = std::make_unique(std::move(*this)); [&]() { ASSERT_TRUE(program->IsValid()) - << formatter.format(program->Diagnostics()); + << diag::Formatter().format(program->Diagnostics()); }(); gen_ = std::make_unique(program.get()); return *gen_; diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 5a89d116fd..db3665ef2f 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -50,6 +50,7 @@ #include "src/sem/variable.h" #include "src/sem/vector_type.h" #include "src/sem/void_type.h" +#include "src/transform/msl.h" #include "src/utils/scoped_assignment.h" #include "src/writer/float_to_string.h" @@ -75,6 +76,14 @@ GeneratorImpl::GeneratorImpl(const Program* program) GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate() { + if (!program_->HasTransformApplied()) { + diagnostics_.add_error( + diag::System::Writer, + "MSL writer requires the transform::Msl sanitizer to have been " + "applied to the input program"); + return false; + } + out_ << "#include " << std::endl << std::endl; out_ << "using namespace metal;" << std::endl; diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc index 4cb99f83b1..46bdef568b 100644 --- a/src/writer/msl/generator_impl_test.cc +++ b/src/writer/msl/generator_impl_test.cc @@ -22,6 +22,16 @@ namespace { using MslGeneratorImplTest = TestHelper; +TEST_F(MslGeneratorImplTest, ErrorIfSanitizerNotRun) { + auto program = std::make_unique(std::move(*this)); + GeneratorImpl gen(program.get()); + EXPECT_FALSE(gen.Generate()); + EXPECT_EQ( + gen.error(), + "error: MSL writer requires the transform::Msl sanitizer to have been " + "applied to the input program"); +} + TEST_F(MslGeneratorImplTest, Generate) { Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, ast::DecorationList{ diff --git a/src/writer/msl/test_helper.h b/src/writer/msl/test_helper.h index 3ae78c731d..7400467347 100644 --- a/src/writer/msl/test_helper.h +++ b/src/writer/msl/test_helper.h @@ -43,6 +43,9 @@ class TestHelperBase : public BASE, public ProgramBuilder { if (gen_) { return *gen_; } + // Fake that the MSL sanitizer has been applied, so that we can unit test + // the writer without it erroring. + SetTransformApplied(); [&]() { ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" << diag::Formatter().format(Diagnostics()); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index e6bc0a84fa..4af004a6d3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -35,6 +35,7 @@ #include "src/sem/struct.h" #include "src/sem/variable.h" #include "src/sem/vector_type.h" +#include "src/transform/spirv.h" #include "src/utils/get_or_create.h" #include "src/writer/append_vector.h" @@ -258,6 +259,13 @@ Builder::Builder(const Program* program) Builder::~Builder() = default; bool Builder::Build() { + if (!builder_.HasTransformApplied()) { + error_ = + "SPIR-V writer requires the transform::Spirv sanitizer to have been " + "applied to the input program"; + return false; + } + push_capability(SpvCapabilityShader); push_memory_model(spv::Op::OpMemoryModel, diff --git a/src/writer/spirv/builder_test.cc b/src/writer/spirv/builder_test.cc index 20f23b489b..acf48d6980 100644 --- a/src/writer/spirv/builder_test.cc +++ b/src/writer/spirv/builder_test.cc @@ -22,6 +22,16 @@ namespace { using BuilderTest = TestHelper; +TEST_F(BuilderTest, ErrorIfSanitizerNotRun) { + auto program = std::make_unique(std::move(*this)); + spirv::Builder b(program.get()); + EXPECT_FALSE(b.Build()); + EXPECT_EQ( + b.error(), + "SPIR-V writer requires the transform::Spirv sanitizer to have been " + "applied to the input program"); +} + TEST_F(BuilderTest, InsertsPreamble) { spirv::Builder& b = Build(); diff --git a/src/writer/spirv/test_helper.h b/src/writer/spirv/test_helper.h index 0fe99e7aaa..dd8f5c2230 100644 --- a/src/writer/spirv/test_helper.h +++ b/src/writer/spirv/test_helper.h @@ -43,6 +43,9 @@ class TestHelperBase : public ProgramBuilder, public BASE { if (spirv_builder) { return *spirv_builder; } + // Fake that the SPIR-V sanitizer has been applied, so that we can unit test + // the writer without it erroring. + SetTransformApplied(); [&]() { ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" << diag::Formatter().format(Diagnostics());