mirror of
				https://github.com/encounter/dawn-cmake.git
				synced 2025-10-27 12:10:29 +00:00 
			
		
		
		
	tint/transform: Refactor transforms
Replace the ShouldRun() method with Apply() which will do the transformation if it needs to be done, otherwise returns 'SkipTransform'. This reduces a bunch of duplicated scanning between the old ShouldRun() and Transform(). This change also adjusts code style to make the transforms more consistent. Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681 Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
		
							parent
							
								
									de6db384aa
								
							
						
					
					
						commit
						c6b381495d
					
				| @ -15,6 +15,7 @@ | |||||||
| #include "src/tint/fuzzers/shuffle_transform.h" | #include "src/tint/fuzzers/shuffle_transform.h" | ||||||
| 
 | 
 | ||||||
| #include <random> | #include <random> | ||||||
|  | #include <utility> | ||||||
| 
 | 
 | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| 
 | 
 | ||||||
| @ -22,15 +23,21 @@ namespace tint::fuzzers { | |||||||
| 
 | 
 | ||||||
| ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {} | ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {} | ||||||
| 
 | 
 | ||||||
| void ShuffleTransform::Run(CloneContext& ctx, | transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src, | ||||||
|                            const tint::transform::DataMap&, |                                                           const transform::DataMap&, | ||||||
|                            tint::transform::DataMap&) const { |                                                           transform::DataMap&) const { | ||||||
|     auto decls = ctx.src->AST().GlobalDeclarations(); |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     auto decls = src->AST().GlobalDeclarations(); | ||||||
|     auto rng = std::mt19937_64{seed_}; |     auto rng = std::mt19937_64{seed_}; | ||||||
|     std::shuffle(std::begin(decls), std::end(decls), rng); |     std::shuffle(std::begin(decls), std::end(decls), rng); | ||||||
|     for (auto* decl : decls) { |     for (auto* decl : decls) { | ||||||
|         ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl)); |         b.AST().AddGlobalDeclaration(ctx.Clone(decl)); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::fuzzers
 | }  // namespace tint::fuzzers
 | ||||||
|  | |||||||
| @ -20,16 +20,16 @@ | |||||||
| namespace tint::fuzzers { | namespace tint::fuzzers { | ||||||
| 
 | 
 | ||||||
| /// ShuffleTransform reorders the module scope declarations into a random order
 | /// ShuffleTransform reorders the module scope declarations into a random order
 | ||||||
| class ShuffleTransform : public tint::transform::Transform { | class ShuffleTransform : public transform::Transform { | ||||||
|   public: |   public: | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param seed the random seed to use for the shuffling
 |     /// @param seed the random seed to use for the shuffling
 | ||||||
|     explicit ShuffleTransform(size_t seed); |     explicit ShuffleTransform(size_t seed); | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc transform::Transform::Apply
 | ||||||
|     void Run(CloneContext& ctx, |     ApplyResult Apply(const Program* program, | ||||||
|              const tint::transform::DataMap&, |                       const transform::DataMap& inputs, | ||||||
|              tint::transform::DataMap&) const override; |                       transform::DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     size_t seed_; |     size_t seed_; | ||||||
|  | |||||||
| @ -23,6 +23,7 @@ | |||||||
| 
 | 
 | ||||||
| #include "src/tint/ast/access.h" | #include "src/tint/ast/access.h" | ||||||
| #include "src/tint/ast/address_space.h" | #include "src/tint/ast/address_space.h" | ||||||
|  | #include "src/tint/ast/parameter.h" | ||||||
| #include "src/tint/sem/binding_point.h" | #include "src/tint/sem/binding_point.h" | ||||||
| #include "src/tint/sem/expression.h" | #include "src/tint/sem/expression.h" | ||||||
| #include "src/tint/sem/parameter_usage.h" | #include "src/tint/sem/parameter_usage.h" | ||||||
| @ -212,6 +213,11 @@ class Parameter final : public Castable<Parameter, Variable> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Parameter() override; |     ~Parameter() override; | ||||||
| 
 | 
 | ||||||
|  |     /// @returns the AST declaration node
 | ||||||
|  |     const ast::Parameter* Declaration() const { | ||||||
|  |         return static_cast<const ast::Parameter*>(Variable::Declaration()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// @return the index of the parmeter in the function
 |     /// @return the index of the parmeter in the function
 | ||||||
|     uint32_t Index() const { return index_; } |     uint32_t Index() const { return index_; } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -31,21 +31,29 @@ AddBlockAttribute::AddBlockAttribute() = default; | |||||||
| 
 | 
 | ||||||
| AddBlockAttribute::~AddBlockAttribute() = default; | AddBlockAttribute::~AddBlockAttribute() = default; | ||||||
| 
 | 
 | ||||||
| void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult AddBlockAttribute::Apply(const Program* src, | ||||||
|     auto& sem = ctx.src->Sem(); |                                                 const DataMap&, | ||||||
|  |                                                 DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|     // A map from a type in the source program to a block-decorated wrapper that contains it in the
 |     // A map from a type in the source program to a block-decorated wrapper that contains it in the
 | ||||||
|     // destination program.
 |     // destination program.
 | ||||||
|     utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs; |     utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs; | ||||||
| 
 | 
 | ||||||
|     // Process global 'var' declarations that are buffers.
 |     // Process global 'var' declarations that are buffers.
 | ||||||
|     for (auto* global : ctx.src->AST().GlobalVariables()) { |     bool made_changes = false; | ||||||
|  |     for (auto* global : src->AST().GlobalVariables()) { | ||||||
|         auto* var = sem.Get(global); |         auto* var = sem.Get(global); | ||||||
|         if (!ast::IsHostShareable(var->AddressSpace())) { |         if (!ast::IsHostShareable(var->AddressSpace())) { | ||||||
|             // Not declared in a host-sharable address space
 |             // Not declared in a host-sharable address space
 | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         made_changes = true; | ||||||
|  | 
 | ||||||
|         auto* ty = var->Type()->UnwrapRef(); |         auto* ty = var->Type()->UnwrapRef(); | ||||||
|         auto* str = ty->As<sem::Struct>(); |         auto* str = ty->As<sem::Struct>(); | ||||||
| 
 | 
 | ||||||
| @ -61,33 +69,36 @@ void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|             const char* kMemberName = "inner"; |             const char* kMemberName = "inner"; | ||||||
| 
 | 
 | ||||||
|             auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] { |             auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] { | ||||||
|                 auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(), |                 auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID()); | ||||||
|                                                                          ctx.dst->AllocateNodeID()); |                 auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block"; | ||||||
|                 auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block"; |                 auto* ret = b.create<ast::Struct>( | ||||||
|                 auto* ret = ctx.dst->create<ast::Struct>( |                     b.Symbols().New(wrapper_name), | ||||||
|                     ctx.dst->Symbols().New(wrapper_name), |                     utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))}, | ||||||
|                     utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))}, |  | ||||||
|                     utils::Vector{block}); |                     utils::Vector{block}); | ||||||
|                 ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret); |                 ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret); | ||||||
|                 return ret; |                 return ret; | ||||||
|             }); |             }); | ||||||
|             ctx.Replace(global->type, ctx.dst->ty.Of(wrapper)); |             ctx.Replace(global->type, b.ty.Of(wrapper)); | ||||||
| 
 | 
 | ||||||
|             // Insert a member accessor to get the original type from the wrapper at
 |             // Insert a member accessor to get the original type from the wrapper at
 | ||||||
|             // any usage of the original variable.
 |             // any usage of the original variable.
 | ||||||
|             for (auto* user : var->Users()) { |             for (auto* user : var->Users()) { | ||||||
|                 ctx.Replace(user->Declaration(), |                 ctx.Replace(user->Declaration(), | ||||||
|                             ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName)); |                             b.MemberAccessor(ctx.Clone(global->symbol), kMemberName)); | ||||||
|             } |             } | ||||||
|         } else { |         } else { | ||||||
|             // Add a block attribute to this struct directly.
 |             // Add a block attribute to this struct directly.
 | ||||||
|             auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(), |             auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID()); | ||||||
|                                                                      ctx.dst->AllocateNodeID()); |  | ||||||
|             ctx.InsertFront(str->Declaration()->attributes, block); |             ctx.InsertFront(str->Declaration()->attributes, block); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if (!made_changes) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid) | AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid) | ||||||
|  | |||||||
| @ -53,14 +53,10 @@ class AddBlockAttribute final : public Castable<AddBlockAttribute, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~AddBlockAttribute() override; |     ~AddBlockAttribute() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -23,12 +23,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint); | |||||||
| using namespace tint::number_suffixes;  // NOLINT
 | using namespace tint::number_suffixes;  // NOLINT
 | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
|  | namespace { | ||||||
| 
 | 
 | ||||||
| AddEmptyEntryPoint::AddEmptyEntryPoint() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; |  | ||||||
| 
 |  | ||||||
| bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* func : program->AST().Functions()) { |     for (auto* func : program->AST().Functions()) { | ||||||
|         if (func->IsEntryPoint()) { |         if (func->IsEntryPoint()) { | ||||||
|             return false; |             return false; | ||||||
| @ -37,13 +34,30 @@ bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|     ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {}, | 
 | ||||||
|                   utils::Vector{ | AddEmptyEntryPoint::AddEmptyEntryPoint() = default; | ||||||
|                       ctx.dst->Stage(ast::PipelineStage::kCompute), | 
 | ||||||
|                       ctx.dst->WorkgroupSize(1_i), | AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; | ||||||
|                   }); | 
 | ||||||
|  | Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src, | ||||||
|  |                                                  const DataMap&, | ||||||
|  |                                                  DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {}, | ||||||
|  |            utils::Vector{ | ||||||
|  |                b.Stage(ast::PipelineStage::kCompute), | ||||||
|  |                b.WorkgroupSize(1_i), | ||||||
|  |            }); | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -27,19 +27,10 @@ class AddEmptyEntryPoint final : public Castable<AddEmptyEntryPoint, Transform> | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~AddEmptyEntryPoint() override; |     ~AddEmptyEntryPoint() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,13 +31,153 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* fn : program->AST().Functions()) { | ||||||
|  |         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||||
|  |             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||||
|  |                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||||
|  |                     return true; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
| ArrayLengthFromUniform::ArrayLengthFromUniform() = default; | ArrayLengthFromUniform::ArrayLengthFromUniform() = default; | ||||||
| ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; | ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for this transform
 | /// PIMPL state for the transform
 | ||||||
| struct ArrayLengthFromUniform::State { | struct ArrayLengthFromUniform::State { | ||||||
|  |     /// Constructor
 | ||||||
|  |     /// @param program the source program
 | ||||||
|  |     /// @param in the input transform data
 | ||||||
|  |     /// @param out the output transform data
 | ||||||
|  |     explicit State(const Program* program, const DataMap& in, DataMap& out) | ||||||
|  |         : src(program), inputs(in), outputs(out) {} | ||||||
|  | 
 | ||||||
|  |     /// Runs the transform
 | ||||||
|  |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|  |         auto* cfg = inputs.Get<Config>(); | ||||||
|  |         if (cfg == nullptr) { | ||||||
|  |             b.Diagnostics().add_error(diag::System::Transform, | ||||||
|  |                                       "missing transform data for " + | ||||||
|  |                                           std::string(TypeInfo::Of<ArrayLengthFromUniform>().name)); | ||||||
|  |             return Program(std::move(b)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (!ShouldRun(ctx.src)) { | ||||||
|  |             return SkipTransform; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         const char* kBufferSizeMemberName = "buffer_size"; | ||||||
|  | 
 | ||||||
|  |         // Determine the size of the buffer size array.
 | ||||||
|  |         uint32_t max_buffer_size_index = 0; | ||||||
|  | 
 | ||||||
|  |         IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*, | ||||||
|  |                                            const sem::GlobalVariable* var) { | ||||||
|  |             auto binding = var->BindingPoint(); | ||||||
|  |             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||||
|  |             if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             if (idx_itr->second > max_buffer_size_index) { | ||||||
|  |                 max_buffer_size_index = idx_itr->second; | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         // Get (or create, on first call) the uniform buffer that will receive the
 | ||||||
|  |         // size of each storage buffer in the module.
 | ||||||
|  |         const ast::Variable* buffer_size_ubo = nullptr; | ||||||
|  |         auto get_ubo = [&]() { | ||||||
|  |             if (!buffer_size_ubo) { | ||||||
|  |                 // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
 | ||||||
|  |                 // We do this because UBOs require an element stride that is 16-byte
 | ||||||
|  |                 // aligned.
 | ||||||
|  |                 auto* buffer_size_struct = b.Structure( | ||||||
|  |                     b.Sym(), utils::Vector{ | ||||||
|  |                                  b.Member(kBufferSizeMemberName, | ||||||
|  |                                           b.ty.array(b.ty.vec4(b.ty.u32()), | ||||||
|  |                                                      u32((max_buffer_size_index / 4) + 1))), | ||||||
|  |                              }); | ||||||
|  |                 buffer_size_ubo = | ||||||
|  |                     b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform, | ||||||
|  |                                 b.Group(AInt(cfg->ubo_binding.group)), | ||||||
|  |                                 b.Binding(AInt(cfg->ubo_binding.binding))); | ||||||
|  |             } | ||||||
|  |             return buffer_size_ubo; | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         std::unordered_set<uint32_t> used_size_indices; | ||||||
|  | 
 | ||||||
|  |         IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, | ||||||
|  |                                            const sem::VariableUser* storage_buffer_sem, | ||||||
|  |                                            const sem::GlobalVariable* var) { | ||||||
|  |             auto binding = var->BindingPoint(); | ||||||
|  |             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||||
|  |             if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             uint32_t size_index = idx_itr->second; | ||||||
|  |             used_size_indices.insert(size_index); | ||||||
|  | 
 | ||||||
|  |             // Load the total storage buffer size from the UBO.
 | ||||||
|  |             uint32_t array_index = size_index / 4; | ||||||
|  |             auto* vec_expr = b.IndexAccessor( | ||||||
|  |                 b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); | ||||||
|  |             uint32_t vec_index = size_index % 4; | ||||||
|  |             auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index)); | ||||||
|  | 
 | ||||||
|  |             // Calculate actual array length
 | ||||||
|  |             //                total_storage_buffer_size - array_offset
 | ||||||
|  |             // array_length = ----------------------------------------
 | ||||||
|  |             //                             array_stride
 | ||||||
|  |             const ast::Expression* total_size = total_storage_buffer_size; | ||||||
|  |             auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); | ||||||
|  |             const sem::Array* array_type = nullptr; | ||||||
|  |             if (auto* str = storage_buffer_type->As<sem::Struct>()) { | ||||||
|  |                 // The variable is a struct, so subtract the byte offset of the array
 | ||||||
|  |                 // member.
 | ||||||
|  |                 auto* array_member_sem = str->Members().back(); | ||||||
|  |                 array_type = array_member_sem->Type()->As<sem::Array>(); | ||||||
|  |                 total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); | ||||||
|  |             } else if (auto* arr = storage_buffer_type->As<sem::Array>()) { | ||||||
|  |                 array_type = arr; | ||||||
|  |             } else { | ||||||
|  |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|  |                     << "expected form of arrayLength argument to be &array_var or " | ||||||
|  |                        "&struct_var.array_member"; | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             auto* array_length = b.Div(total_size, u32(array_type->Stride())); | ||||||
|  | 
 | ||||||
|  |             ctx.Replace(call_expr, array_length); | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         outputs.Add<Result>(used_size_indices); | ||||||
|  | 
 | ||||||
|  |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The transform inputs
 | ||||||
|  |     const DataMap& inputs; | ||||||
|  |     /// The transform outputs
 | ||||||
|  |     DataMap& outputs; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
| 
 | 
 | ||||||
|     /// Iterate over all arrayLength() builtins that operate on
 |     /// Iterate over all arrayLength() builtins that operate on
 | ||||||
|     /// storage buffer variables.
 |     /// storage buffer variables.
 | ||||||
| @ -48,10 +188,10 @@ struct ArrayLengthFromUniform::State { | |||||||
|     /// sem::GlobalVariable for the storage buffer.
 |     /// sem::GlobalVariable for the storage buffer.
 | ||||||
|     template <typename F> |     template <typename F> | ||||||
|     void IterateArrayLengthOnStorageVar(F&& functor) { |     void IterateArrayLengthOnStorageVar(F&& functor) { | ||||||
|         auto& sem = ctx.src->Sem(); |         auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|         // Find all calls to the arrayLength() builtin.
 |         // Find all calls to the arrayLength() builtin.
 | ||||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { |         for (auto* node : src->ASTNodes().Objects()) { | ||||||
|             auto* call_expr = node->As<ast::CallExpression>(); |             auto* call_expr = node->As<ast::CallExpression>(); | ||||||
|             if (!call_expr) { |             if (!call_expr) { | ||||||
|                 continue; |                 continue; | ||||||
| @ -79,7 +219,7 @@ struct ArrayLengthFromUniform::State { | |||||||
|             //   arrayLength(&array_var)
 |             //   arrayLength(&array_var)
 | ||||||
|             auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>(); |             auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>(); | ||||||
|             if (!param || param->op != ast::UnaryOp::kAddressOf) { |             if (!param || param->op != ast::UnaryOp::kAddressOf) { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "expected form of arrayLength argument to be &array_var or " |                     << "expected form of arrayLength argument to be &array_var or " | ||||||
|                        "&struct_var.array_member"; |                        "&struct_var.array_member"; | ||||||
|                 break; |                 break; | ||||||
| @ -90,7 +230,7 @@ struct ArrayLengthFromUniform::State { | |||||||
|             } |             } | ||||||
|             auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); |             auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); | ||||||
|             if (!storage_buffer_sem) { |             if (!storage_buffer_sem) { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "expected form of arrayLength argument to be &array_var or " |                     << "expected form of arrayLength argument to be &array_var or " | ||||||
|                        "&struct_var.array_member"; |                        "&struct_var.array_member"; | ||||||
|                 break; |                 break; | ||||||
| @ -99,8 +239,7 @@ struct ArrayLengthFromUniform::State { | |||||||
|             // Get the index to use for the buffer size array.
 |             // Get the index to use for the buffer size array.
 | ||||||
|             auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable()); |             auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable()); | ||||||
|             if (!var) { |             if (!var) { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable"; | ||||||
|                     << "storage buffer is not a global variable"; |  | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             functor(call_expr, storage_buffer_sem, var); |             functor(call_expr, storage_buffer_sem, var); | ||||||
| @ -108,117 +247,10 @@ struct ArrayLengthFromUniform::State { | |||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src, | ||||||
|     for (auto* fn : program->AST().Functions()) { |                                                      const DataMap& inputs, | ||||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { |                                                      DataMap& outputs) const { | ||||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { |     return State{src, inputs, outputs}.Run(); | ||||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { |  | ||||||
|                     return true; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { |  | ||||||
|     auto* cfg = inputs.Get<Config>(); |  | ||||||
|     if (cfg == nullptr) { |  | ||||||
|         ctx.dst->Diagnostics().add_error( |  | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     const char* kBufferSizeMemberName = "buffer_size"; |  | ||||||
| 
 |  | ||||||
|     // Determine the size of the buffer size array.
 |  | ||||||
|     uint32_t max_buffer_size_index = 0; |  | ||||||
| 
 |  | ||||||
|     State{ctx}.IterateArrayLengthOnStorageVar( |  | ||||||
|         [&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) { |  | ||||||
|             auto binding = var->BindingPoint(); |  | ||||||
|             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); |  | ||||||
|             if (idx_itr == cfg->bindpoint_to_size_index.end()) { |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|             if (idx_itr->second > max_buffer_size_index) { |  | ||||||
|                 max_buffer_size_index = idx_itr->second; |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|     // Get (or create, on first call) the uniform buffer that will receive the
 |  | ||||||
|     // size of each storage buffer in the module.
 |  | ||||||
|     const ast::Variable* buffer_size_ubo = nullptr; |  | ||||||
|     auto get_ubo = [&]() { |  | ||||||
|         if (!buffer_size_ubo) { |  | ||||||
|             // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
 |  | ||||||
|             // We do this because UBOs require an element stride that is 16-byte
 |  | ||||||
|             // aligned.
 |  | ||||||
|             auto* buffer_size_struct = ctx.dst->Structure( |  | ||||||
|                 ctx.dst->Sym(), |  | ||||||
|                 utils::Vector{ |  | ||||||
|                     ctx.dst->Member(kBufferSizeMemberName, |  | ||||||
|                                     ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), |  | ||||||
|                                                       u32((max_buffer_size_index / 4) + 1))), |  | ||||||
|                 }); |  | ||||||
|             buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), |  | ||||||
|                                                  ast::AddressSpace::kUniform, |  | ||||||
|                                                  ctx.dst->Group(AInt(cfg->ubo_binding.group)), |  | ||||||
|                                                  ctx.dst->Binding(AInt(cfg->ubo_binding.binding))); |  | ||||||
|         } |  | ||||||
|         return buffer_size_ubo; |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     std::unordered_set<uint32_t> used_size_indices; |  | ||||||
| 
 |  | ||||||
|     State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, |  | ||||||
|                                                   const sem::VariableUser* storage_buffer_sem, |  | ||||||
|                                                   const sem::GlobalVariable* var) { |  | ||||||
|         auto binding = var->BindingPoint(); |  | ||||||
|         auto idx_itr = cfg->bindpoint_to_size_index.find(binding); |  | ||||||
|         if (idx_itr == cfg->bindpoint_to_size_index.end()) { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         uint32_t size_index = idx_itr->second; |  | ||||||
|         used_size_indices.insert(size_index); |  | ||||||
| 
 |  | ||||||
|         // Load the total storage buffer size from the UBO.
 |  | ||||||
|         uint32_t array_index = size_index / 4; |  | ||||||
|         auto* vec_expr = ctx.dst->IndexAccessor( |  | ||||||
|             ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); |  | ||||||
|         uint32_t vec_index = size_index % 4; |  | ||||||
|         auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index)); |  | ||||||
| 
 |  | ||||||
|         // Calculate actual array length
 |  | ||||||
|         //                total_storage_buffer_size - array_offset
 |  | ||||||
|         // array_length = ----------------------------------------
 |  | ||||||
|         //                             array_stride
 |  | ||||||
|         const ast::Expression* total_size = total_storage_buffer_size; |  | ||||||
|         auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); |  | ||||||
|         const sem::Array* array_type = nullptr; |  | ||||||
|         if (auto* str = storage_buffer_type->As<sem::Struct>()) { |  | ||||||
|             // The variable is a struct, so subtract the byte offset of the array
 |  | ||||||
|             // member.
 |  | ||||||
|             auto* array_member_sem = str->Members().back(); |  | ||||||
|             array_type = array_member_sem->Type()->As<sem::Array>(); |  | ||||||
|             total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); |  | ||||||
|         } else if (auto* arr = storage_buffer_type->As<sem::Array>()) { |  | ||||||
|             array_type = arr; |  | ||||||
|         } else { |  | ||||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) |  | ||||||
|                 << "expected form of arrayLength argument to be &array_var or " |  | ||||||
|                    "&struct_var.array_member"; |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride())); |  | ||||||
| 
 |  | ||||||
|         ctx.Replace(call_expr, array_length); |  | ||||||
|     }); |  | ||||||
| 
 |  | ||||||
|     ctx.Clone(); |  | ||||||
| 
 |  | ||||||
|     outputs.Add<Result>(used_size_indices); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} | ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} | ||||||
|  | |||||||
| @ -100,22 +100,12 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra | |||||||
|         std::unordered_set<uint32_t> used_size_indices; |         std::unordered_set<uint32_t> used_size_indices; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     /// The PIMPL state for this transform
 |  | ||||||
|     struct State; |     struct State; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -28,7 +28,13 @@ using ArrayLengthFromUniformTest = TransformTest; | |||||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) { | TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) { | ||||||
|     auto* src = R"()"; |     auto* src = R"()"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src)); |     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||||
|  |     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||||
|  | 
 | ||||||
|  |     DataMap data; | ||||||
|  |     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||||
|  | 
 | ||||||
|  |     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) { | TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) { | ||||||
| @ -45,7 +51,13 @@ fn main() { | |||||||
| } | } | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src)); |     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||||
|  |     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||||
|  | 
 | ||||||
|  |     DataMap data; | ||||||
|  |     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||||
|  | 
 | ||||||
|  |     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) { | TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) { | ||||||
| @ -63,7 +75,13 @@ fn main() { | |||||||
| } | } | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src)); |     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||||
|  |     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||||
|  | 
 | ||||||
|  |     DataMap data; | ||||||
|  |     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||||
|  | 
 | ||||||
|  |     EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { | TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { | ||||||
|  | |||||||
| @ -40,19 +40,21 @@ BindingRemapper::Remappings::~Remappings() = default; | |||||||
| BindingRemapper::BindingRemapper() = default; | BindingRemapper::BindingRemapper() = default; | ||||||
| BindingRemapper::~BindingRemapper() = default; | BindingRemapper::~BindingRemapper() = default; | ||||||
| 
 | 
 | ||||||
| bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const { | Transform::ApplyResult BindingRemapper::Apply(const Program* src, | ||||||
|     if (auto* remappings = inputs.Get<Remappings>()) { |                                               const DataMap& inputs, | ||||||
|         return !remappings->binding_points.empty() || !remappings->access_controls.empty(); |                                               DataMap&) const { | ||||||
|     } |     ProgramBuilder b; | ||||||
|     return false; |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { |  | ||||||
|     auto* remappings = inputs.Get<Remappings>(); |     auto* remappings = inputs.Get<Remappings>(); | ||||||
|     if (!remappings) { |     if (!remappings) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||||
|         return; |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (remappings->binding_points.empty() && remappings->access_controls.empty()) { | ||||||
|  |         return SkipTransform; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // A set of post-remapped binding points that need to be decorated with a
 |     // A set of post-remapped binding points that need to be decorated with a
 | ||||||
| @ -62,11 +64,11 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | |||||||
|     if (remappings->allow_collisions) { |     if (remappings->allow_collisions) { | ||||||
|         // Scan for binding point collisions generated by this transform.
 |         // Scan for binding point collisions generated by this transform.
 | ||||||
|         // Populate all collisions in the `add_collision_attr` set.
 |         // Populate all collisions in the `add_collision_attr` set.
 | ||||||
|         for (auto* func_ast : ctx.src->AST().Functions()) { |         for (auto* func_ast : src->AST().Functions()) { | ||||||
|             if (!func_ast->IsEntryPoint()) { |             if (!func_ast->IsEntryPoint()) { | ||||||
|                 continue; |                 continue; | ||||||
|             } |             } | ||||||
|             auto* func = ctx.src->Sem().Get(func_ast); |             auto* func = src->Sem().Get(func_ast); | ||||||
|             std::unordered_map<sem::BindingPoint, int> binding_point_counts; |             std::unordered_map<sem::BindingPoint, int> binding_point_counts; | ||||||
|             for (auto* global : func->TransitivelyReferencedGlobals()) { |             for (auto* global : func->TransitivelyReferencedGlobals()) { | ||||||
|                 if (global->Declaration()->HasBindingPoint()) { |                 if (global->Declaration()->HasBindingPoint()) { | ||||||
| @ -90,9 +92,9 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     for (auto* var : ctx.src->AST().Globals<ast::Var>()) { |     for (auto* var : src->AST().Globals<ast::Var>()) { | ||||||
|         if (var->HasBindingPoint()) { |         if (var->HasBindingPoint()) { | ||||||
|             auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var); |             auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var); | ||||||
| 
 | 
 | ||||||
|             // The original binding point
 |             // The original binding point
 | ||||||
|             BindingPoint from = global_sem->BindingPoint(); |             BindingPoint from = global_sem->BindingPoint(); | ||||||
| @ -106,8 +108,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | |||||||
|             auto bp_it = remappings->binding_points.find(from); |             auto bp_it = remappings->binding_points.find(from); | ||||||
|             if (bp_it != remappings->binding_points.end()) { |             if (bp_it != remappings->binding_points.end()) { | ||||||
|                 BindingPoint to = bp_it->second; |                 BindingPoint to = bp_it->second; | ||||||
|                 auto* new_group = ctx.dst->Group(AInt(to.group)); |                 auto* new_group = b.Group(AInt(to.group)); | ||||||
|                 auto* new_binding = ctx.dst->Binding(AInt(to.binding)); |                 auto* new_binding = b.Binding(AInt(to.binding)); | ||||||
| 
 | 
 | ||||||
|                 auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes); |                 auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes); | ||||||
|                 auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes); |                 auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes); | ||||||
| @ -122,37 +124,37 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | |||||||
|             if (ac_it != remappings->access_controls.end()) { |             if (ac_it != remappings->access_controls.end()) { | ||||||
|                 ast::Access ac = ac_it->second; |                 ast::Access ac = ac_it->second; | ||||||
|                 if (ac == ast::Access::kUndefined) { |                 if (ac == ast::Access::kUndefined) { | ||||||
|                     ctx.dst->Diagnostics().add_error( |                     b.Diagnostics().add_error( | ||||||
|                         diag::System::Transform, |                         diag::System::Transform, | ||||||
|                         "invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")"); |                         "invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")"); | ||||||
|                     return; |                     return Program(std::move(b)); | ||||||
|                 } |                 } | ||||||
|                 auto* sem = ctx.src->Sem().Get(var); |                 auto* sem = src->Sem().Get(var); | ||||||
|                 if (sem->AddressSpace() != ast::AddressSpace::kStorage) { |                 if (sem->AddressSpace() != ast::AddressSpace::kStorage) { | ||||||
|                     ctx.dst->Diagnostics().add_error( |                     b.Diagnostics().add_error( | ||||||
|                         diag::System::Transform, |                         diag::System::Transform, | ||||||
|                         "cannot apply access control to variable with address space " + |                         "cannot apply access control to variable with address space " + | ||||||
|                             std::string(utils::ToString(sem->AddressSpace()))); |                             std::string(utils::ToString(sem->AddressSpace()))); | ||||||
|                     return; |                     return Program(std::move(b)); | ||||||
|                 } |                 } | ||||||
|                 auto* ty = sem->Type()->UnwrapRef(); |                 auto* ty = sem->Type()->UnwrapRef(); | ||||||
|                 const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); |                 const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); | ||||||
|                 auto* new_var = |                 auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, | ||||||
|                     ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, |                                       var->declared_address_space, ac, ctx.Clone(var->initializer), | ||||||
|                                  var->declared_address_space, ac, ctx.Clone(var->initializer), |                                       ctx.Clone(var->attributes)); | ||||||
|                                  ctx.Clone(var->attributes)); |  | ||||||
|                 ctx.Replace(var, new_var); |                 ctx.Replace(var, new_var); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // Add `DisableValidationAttribute`s if required
 |             // Add `DisableValidationAttribute`s if required
 | ||||||
|             if (add_collision_attr.count(bp)) { |             if (add_collision_attr.count(bp)) { | ||||||
|                 auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision); |                 auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision); | ||||||
|                 ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute); |                 ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -67,19 +67,10 @@ class BindingRemapper final : public Castable<BindingRemapper, Transform> { | |||||||
|     BindingRemapper(); |     BindingRemapper(); | ||||||
|     ~BindingRemapper() override; |     ~BindingRemapper() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -23,12 +23,6 @@ namespace { | |||||||
| 
 | 
 | ||||||
| using BindingRemapperTest = TransformTest; | using BindingRemapperTest = TransformTest; | ||||||
| 
 | 
 | ||||||
| TEST_F(BindingRemapperTest, ShouldRunNoRemappings) { |  | ||||||
|     auto* src = R"()"; |  | ||||||
| 
 |  | ||||||
|     EXPECT_FALSE(ShouldRun<BindingRemapper>(src)); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) { | TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) { | ||||||
|     auto* src = R"()"; |     auto* src = R"()"; | ||||||
| 
 | 
 | ||||||
| @ -350,7 +344,7 @@ fn f() { | |||||||
| } | } | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     auto* expect = src; |     auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)"; | ||||||
| 
 | 
 | ||||||
|     auto got = Run<BindingRemapper>(src); |     auto got = Run<BindingRemapper>(src); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -29,7 +29,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::BuiltinPolyfill::Config); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the BuiltinPolyfill transform
 | /// PIMPL state for the transform
 | ||||||
| struct BuiltinPolyfill::State { | struct BuiltinPolyfill::State { | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param c the CloneContext
 |     /// @param c the CloneContext
 | ||||||
| @ -604,193 +604,100 @@ BuiltinPolyfill::BuiltinPolyfill() = default; | |||||||
| 
 | 
 | ||||||
| BuiltinPolyfill::~BuiltinPolyfill() = default; | BuiltinPolyfill::~BuiltinPolyfill() = default; | ||||||
| 
 | 
 | ||||||
| bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const { | Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, | ||||||
|     if (auto* cfg = data.Get<Config>()) { |                                               const DataMap& data, | ||||||
|         auto builtins = cfg->builtins; |                                               DataMap&) const { | ||||||
|         auto& sem = program->Sem(); |  | ||||||
|         for (auto* node : program->ASTNodes().Objects()) { |  | ||||||
|             if (auto* call = sem.Get<sem::Call>(node)) { |  | ||||||
|                 if (auto* builtin = call->Target()->As<sem::Builtin>()) { |  | ||||||
|                     if (call->Stage() == sem::EvaluationStage::kConstant) { |  | ||||||
|                         continue;  // Don't polyfill @const expressions
 |  | ||||||
|                     } |  | ||||||
|                     switch (builtin->Type()) { |  | ||||||
|                         case sem::BuiltinType::kAcosh: |  | ||||||
|                             if (builtins.acosh != Level::kNone) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kAsinh: |  | ||||||
|                             if (builtins.asinh) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kAtanh: |  | ||||||
|                             if (builtins.atanh != Level::kNone) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kClamp: |  | ||||||
|                             if (builtins.clamp_int) { |  | ||||||
|                                 auto& sig = builtin->Signature(); |  | ||||||
|                                 return sig.parameters[0]->Type()->is_integer_scalar_or_vector(); |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kCountLeadingZeros: |  | ||||||
|                             if (builtins.count_leading_zeros) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kCountTrailingZeros: |  | ||||||
|                             if (builtins.count_trailing_zeros) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kExtractBits: |  | ||||||
|                             if (builtins.extract_bits != Level::kNone) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kFirstLeadingBit: |  | ||||||
|                             if (builtins.first_leading_bit) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kFirstTrailingBit: |  | ||||||
|                             if (builtins.first_trailing_bit) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kInsertBits: |  | ||||||
|                             if (builtins.insert_bits != Level::kNone) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kSaturate: |  | ||||||
|                             if (builtins.saturate) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kTextureSampleBaseClampToEdge: |  | ||||||
|                             if (builtins.texture_sample_base_clamp_to_edge_2d_f32) { |  | ||||||
|                                 auto& sig = builtin->Signature(); |  | ||||||
|                                 auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); |  | ||||||
|                                 if (auto* stex = tex->Type()->As<sem::SampledTexture>()) { |  | ||||||
|                                     return stex->type()->Is<sem::F32>(); |  | ||||||
|                                 } |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         case sem::BuiltinType::kQuantizeToF16: |  | ||||||
|                             if (builtins.quantize_to_vec_f16) { |  | ||||||
|                                 if (builtin->ReturnType()->Is<sem::Vector>()) { |  | ||||||
|                                     return true; |  | ||||||
|                                 } |  | ||||||
|                             } |  | ||||||
|                             break; |  | ||||||
|                         default: |  | ||||||
|                             break; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const { |  | ||||||
|     auto* cfg = data.Get<Config>(); |     auto* cfg = data.Get<Config>(); | ||||||
|     if (!cfg) { |     if (!cfg) { | ||||||
|         ctx.Clone(); |         return SkipTransform; | ||||||
|         return; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     std::unordered_map<const sem::Builtin*, Symbol> polyfills; |     auto& builtins = cfg->builtins; | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { |     utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills; | ||||||
|         auto builtins = cfg->builtins; | 
 | ||||||
|         State s{ctx, builtins}; |     ProgramBuilder b; | ||||||
|         if (auto* call = s.sem.Get<sem::Call>(expr)) { |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     State s{ctx, builtins}; | ||||||
|  | 
 | ||||||
|  |     bool made_changes = false; | ||||||
|  |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|  |         if (auto* call = src->Sem().Get<sem::Call>(node)) { | ||||||
|             if (auto* builtin = call->Target()->As<sem::Builtin>()) { |             if (auto* builtin = call->Target()->As<sem::Builtin>()) { | ||||||
|                 if (call->Stage() == sem::EvaluationStage::kConstant) { |                 if (call->Stage() == sem::EvaluationStage::kConstant) { | ||||||
|                     return nullptr;  // Don't polyfill @const expressions
 |                     continue;  // Don't polyfill @const expressions
 | ||||||
|                 } |                 } | ||||||
|                 Symbol polyfill; |                 Symbol polyfill; | ||||||
|                 switch (builtin->Type()) { |                 switch (builtin->Type()) { | ||||||
|                     case sem::BuiltinType::kAcosh: |                     case sem::BuiltinType::kAcosh: | ||||||
|                         if (builtins.acosh != Level::kNone) { |                         if (builtins.acosh != Level::kNone) { | ||||||
|                             polyfill = utils::GetOrCreate( |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); }); |                                 builtin, [&] { return s.acosh(builtin->ReturnType()); }); | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kAsinh: |                     case sem::BuiltinType::kAsinh: | ||||||
|                         if (builtins.asinh) { |                         if (builtins.asinh) { | ||||||
|                             polyfill = utils::GetOrCreate( |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); }); |                                 builtin, [&] { return s.asinh(builtin->ReturnType()); }); | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kAtanh: |                     case sem::BuiltinType::kAtanh: | ||||||
|                         if (builtins.atanh != Level::kNone) { |                         if (builtins.atanh != Level::kNone) { | ||||||
|                             polyfill = utils::GetOrCreate( |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); }); |                                 builtin, [&] { return s.atanh(builtin->ReturnType()); }); | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kClamp: |                     case sem::BuiltinType::kClamp: | ||||||
|                         if (builtins.clamp_int) { |                         if (builtins.clamp_int) { | ||||||
|                             auto& sig = builtin->Signature(); |                             auto& sig = builtin->Signature(); | ||||||
|                             if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) { |                             if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) { | ||||||
|                                 polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                                 polyfill = polyfills.GetOrCreate( | ||||||
|                                     return s.clampInteger(builtin->ReturnType()); |                                     builtin, [&] { return s.clampInteger(builtin->ReturnType()); }); | ||||||
|                                 }); |  | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kCountLeadingZeros: |                     case sem::BuiltinType::kCountLeadingZeros: | ||||||
|                         if (builtins.count_leading_zeros) { |                         if (builtins.count_leading_zeros) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||||
|                                 return s.countLeadingZeros(builtin->ReturnType()); |                                 return s.countLeadingZeros(builtin->ReturnType()); | ||||||
|                             }); |                             }); | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kCountTrailingZeros: |                     case sem::BuiltinType::kCountTrailingZeros: | ||||||
|                         if (builtins.count_trailing_zeros) { |                         if (builtins.count_trailing_zeros) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||||
|                                 return s.countTrailingZeros(builtin->ReturnType()); |                                 return s.countTrailingZeros(builtin->ReturnType()); | ||||||
|                             }); |                             }); | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kExtractBits: |                     case sem::BuiltinType::kExtractBits: | ||||||
|                         if (builtins.extract_bits != Level::kNone) { |                         if (builtins.extract_bits != Level::kNone) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 return s.extractBits(builtin->ReturnType()); |                                 builtin, [&] { return s.extractBits(builtin->ReturnType()); }); | ||||||
|                             }); |  | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kFirstLeadingBit: |                     case sem::BuiltinType::kFirstLeadingBit: | ||||||
|                         if (builtins.first_leading_bit) { |                         if (builtins.first_leading_bit) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 return s.firstLeadingBit(builtin->ReturnType()); |                                 builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); }); | ||||||
|                             }); |  | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kFirstTrailingBit: |                     case sem::BuiltinType::kFirstTrailingBit: | ||||||
|                         if (builtins.first_trailing_bit) { |                         if (builtins.first_trailing_bit) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 return s.firstTrailingBit(builtin->ReturnType()); |                                 builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); }); | ||||||
|                             }); |  | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kInsertBits: |                     case sem::BuiltinType::kInsertBits: | ||||||
|                         if (builtins.insert_bits != Level::kNone) { |                         if (builtins.insert_bits != Level::kNone) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 return s.insertBits(builtin->ReturnType()); |                                 builtin, [&] { return s.insertBits(builtin->ReturnType()); }); | ||||||
|                             }); |  | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kSaturate: |                     case sem::BuiltinType::kSaturate: | ||||||
|                         if (builtins.saturate) { |                         if (builtins.saturate) { | ||||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                             polyfill = polyfills.GetOrCreate( | ||||||
|                                 return s.saturate(builtin->ReturnType()); |                                 builtin, [&] { return s.saturate(builtin->ReturnType()); }); | ||||||
|                             }); |  | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
|                     case sem::BuiltinType::kTextureSampleBaseClampToEdge: |                     case sem::BuiltinType::kTextureSampleBaseClampToEdge: | ||||||
| @ -799,7 +706,7 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | |||||||
|                             auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); |                             auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); | ||||||
|                             if (auto* stex = tex->Type()->As<sem::SampledTexture>()) { |                             if (auto* stex = tex->Type()->As<sem::SampledTexture>()) { | ||||||
|                                 if (stex->type()->Is<sem::F32>()) { |                                 if (stex->type()->Is<sem::F32>()) { | ||||||
|                                     polyfill = utils::GetOrCreate(polyfills, builtin, [&] { |                                     polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||||
|                                         return s.textureSampleBaseClampToEdge_2d_f32(); |                                         return s.textureSampleBaseClampToEdge_2d_f32(); | ||||||
|                                     }); |                                     }); | ||||||
|                                 } |                                 } | ||||||
| @ -809,8 +716,8 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | |||||||
|                     case sem::BuiltinType::kQuantizeToF16: |                     case sem::BuiltinType::kQuantizeToF16: | ||||||
|                         if (builtins.quantize_to_vec_f16) { |                         if (builtins.quantize_to_vec_f16) { | ||||||
|                             if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) { |                             if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) { | ||||||
|                                 polyfill = utils::GetOrCreate(polyfills, builtin, |                                 polyfill = polyfills.GetOrCreate( | ||||||
|                                                               [&] { return s.quantizeToF16(vec); }); |                                     builtin, [&] { return s.quantizeToF16(vec); }); | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                         break; |                         break; | ||||||
| @ -819,14 +726,20 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | |||||||
|                         break; |                         break; | ||||||
|                 } |                 } | ||||||
|                 if (polyfill.IsValid()) { |                 if (polyfill.IsValid()) { | ||||||
|                     return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); |                     auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); | ||||||
|  |                     ctx.Replace(call->Declaration(), replacement); | ||||||
|  |                     made_changes = true; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         return nullptr; |     } | ||||||
|     }); | 
 | ||||||
|  |     if (!made_changes) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {} | BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {} | ||||||
|  | |||||||
| @ -87,21 +87,13 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> { | |||||||
|         const Builtins builtins; |         const Builtins builtins; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   protected: |   private: | ||||||
|     struct State; |     struct State; | ||||||
| 
 |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -1561,7 +1561,8 @@ fn f() { | |||||||
| TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) { | TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) { | ||||||
|     auto* src = R"( |     auto* src = R"( | ||||||
| fn f() { | fn f() { | ||||||
|   let r : i32 = insertBits(1234, 5678, 5u, 6u); |   let v = 1234i; | ||||||
|  |   let r : i32 = insertBits(v, 5678, 5u, 6u); | ||||||
| } | } | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
| @ -1975,10 +1976,6 @@ fn f() { | |||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     auto* expect = R"( |     auto* expect = R"( | ||||||
| @group(0) @binding(0) var t : texture_2d<f32>; |  | ||||||
| 
 |  | ||||||
| @group(0) @binding(1) var s : sampler; |  | ||||||
| 
 |  | ||||||
| fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> { | fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> { | ||||||
|   let dims = vec2<f32>(textureDimensions(t, 0)); |   let dims = vec2<f32>(textureDimensions(t, 0)); | ||||||
|   let half_texel = (vec2<f32>(0.5) / dims); |   let half_texel = (vec2<f32>(0.5) / dims); | ||||||
| @ -1986,6 +1983,10 @@ fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : v | |||||||
|   return textureSampleLevel(t, s, clamped, 0); |   return textureSampleLevel(t, s, clamped, 0); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @group(0) @binding(0) var t : texture_2d<f32>; | ||||||
|  | 
 | ||||||
|  | @group(0) @binding(1) var s : sampler; | ||||||
|  | 
 | ||||||
| fn f() { | fn f() { | ||||||
|   let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5)); |   let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5)); | ||||||
| } | } | ||||||
|  | |||||||
| @ -40,6 +40,19 @@ namespace tint::transform { | |||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* fn : program->AST().Functions()) { | ||||||
|  |         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||||
|  |             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||||
|  |                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||||
|  |                     return true; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// ArrayUsage describes a runtime array usage.
 | /// ArrayUsage describes a runtime array usage.
 | ||||||
| /// It is used as a key by the array_length_by_usage map.
 | /// It is used as a key by the array_length_by_usage map.
 | ||||||
| struct ArrayUsage { | struct ArrayUsage { | ||||||
| @ -73,21 +86,16 @@ const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSiz | |||||||
| CalculateArrayLength::CalculateArrayLength() = default; | CalculateArrayLength::CalculateArrayLength() = default; | ||||||
| CalculateArrayLength::~CalculateArrayLength() = default; | CalculateArrayLength::~CalculateArrayLength() = default; | ||||||
| 
 | 
 | ||||||
| bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult CalculateArrayLength::Apply(const Program* src, | ||||||
|     for (auto* fn : program->AST().Functions()) { |                                                    const DataMap&, | ||||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { |                                                    DataMap&) const { | ||||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { |     if (!ShouldRun(src)) { | ||||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { |         return SkipTransform; | ||||||
|                     return true; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |     ProgramBuilder b; | ||||||
|     auto& sem = ctx.src->Sem(); |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|     // get_buffer_size_intrinsic() emits the function decorated with
 |     // get_buffer_size_intrinsic() emits the function decorated with
 | ||||||
|     // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
 |     // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
 | ||||||
| @ -95,24 +103,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|     std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics; |     std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics; | ||||||
|     auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) { |     auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) { | ||||||
|         return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { |         return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { | ||||||
|             auto name = ctx.dst->Sym(); |             auto name = b.Sym(); | ||||||
|             auto* type = CreateASTTypeFor(ctx, buffer_type); |             auto* type = CreateASTTypeFor(ctx, buffer_type); | ||||||
|             auto* disable_validation = |             auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter); | ||||||
|                 ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter); |             b.AST().AddFunction(b.create<ast::Function>( | ||||||
|             ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>( |  | ||||||
|                 name, |                 name, | ||||||
|                 utils::Vector{ |                 utils::Vector{ | ||||||
|                     ctx.dst->Param("buffer", |                     b.Param("buffer", | ||||||
|                                    ctx.dst->ty.pointer(type, buffer_type->AddressSpace(), |                             b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()), | ||||||
|                                                        buffer_type->Access()), |                             utils::Vector{disable_validation}), | ||||||
|                                    utils::Vector{disable_validation}), |                     b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)), | ||||||
|                     ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(), |  | ||||||
|                                                                  ast::AddressSpace::kFunction)), |  | ||||||
|                 }, |                 }, | ||||||
|                 ctx.dst->ty.void_(), nullptr, |                 b.ty.void_(), nullptr, | ||||||
|                 utils::Vector{ |                 utils::Vector{ | ||||||
|                     ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(), |                     b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()), | ||||||
|                                                                     ctx.dst->AllocateNodeID()), |  | ||||||
|                 }, |                 }, | ||||||
|                 utils::Empty)); |                 utils::Empty)); | ||||||
| 
 | 
 | ||||||
| @ -123,7 +127,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|     std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage; |     std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage; | ||||||
| 
 | 
 | ||||||
|     // Find all the arrayLength() calls...
 |     // Find all the arrayLength() calls...
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* call_expr = node->As<ast::CallExpression>()) { |         if (auto* call_expr = node->As<ast::CallExpression>()) { | ||||||
|             auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>(); |             auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|             if (auto* builtin = call->Target()->As<sem::Builtin>()) { |             if (auto* builtin = call->Target()->As<sem::Builtin>()) { | ||||||
| @ -149,7 +153,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|                     auto* arg = call_expr->args[0]; |                     auto* arg = call_expr->args[0]; | ||||||
|                     auto* address_of = arg->As<ast::UnaryOpExpression>(); |                     auto* address_of = arg->As<ast::UnaryOpExpression>(); | ||||||
|                     if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { |                     if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { | ||||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "arrayLength() expected address-of, got " << arg->TypeInfo().name; |                             << "arrayLength() expected address-of, got " << arg->TypeInfo().name; | ||||||
|                     } |                     } | ||||||
|                     auto* storage_buffer_expr = address_of->expr; |                     auto* storage_buffer_expr = address_of->expr; | ||||||
| @ -158,7 +162,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|                     } |                     } | ||||||
|                     auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); |                     auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); | ||||||
|                     if (!storage_buffer_sem) { |                     if (!storage_buffer_sem) { | ||||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "expected form of arrayLength argument to be &array_var or " |                             << "expected form of arrayLength argument to be &array_var or " | ||||||
|                                "&struct_var.array_member"; |                                "&struct_var.array_member"; | ||||||
|                         break; |                         break; | ||||||
| @ -179,25 +183,24 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
| 
 | 
 | ||||||
|                             // Construct the variable that'll hold the result of
 |                             // Construct the variable that'll hold the result of
 | ||||||
|                             // RWByteAddressBuffer.GetDimensions()
 |                             // RWByteAddressBuffer.GetDimensions()
 | ||||||
|                             auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var( |                             auto* buffer_size_result = | ||||||
|                                 ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u))); |                                 b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u))); | ||||||
| 
 | 
 | ||||||
|                             // Call storage_buffer.GetDimensions(&buffer_size_result)
 |                             // Call storage_buffer.GetDimensions(&buffer_size_result)
 | ||||||
|                             auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( |                             auto* call_get_dims = b.CallStmt(b.Call( | ||||||
|                                 // BufferSizeIntrinsic(X, ARGS...) is
 |                                 // BufferSizeIntrinsic(X, ARGS...) is
 | ||||||
|                                 // translated to:
 |                                 // translated to:
 | ||||||
|                                 //  X.GetDimensions(ARGS..) by the writer
 |                                 //  X.GetDimensions(ARGS..) by the writer
 | ||||||
|                                 buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)), |                                 buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)), | ||||||
|                                 ctx.dst->AddressOf( |                                 b.AddressOf(b.Expr(buffer_size_result->variable->symbol)))); | ||||||
|                                     ctx.dst->Expr(buffer_size_result->variable->symbol)))); |  | ||||||
| 
 | 
 | ||||||
|                             // Calculate actual array length
 |                             // Calculate actual array length
 | ||||||
|                             //                total_storage_buffer_size - array_offset
 |                             //                total_storage_buffer_size - array_offset
 | ||||||
|                             // array_length = ----------------------------------------
 |                             // array_length = ----------------------------------------
 | ||||||
|                             //                             array_stride
 |                             //                             array_stride
 | ||||||
|                             auto name = ctx.dst->Sym(); |                             auto name = b.Sym(); | ||||||
|                             const ast::Expression* total_size = |                             const ast::Expression* total_size = | ||||||
|                                 ctx.dst->Expr(buffer_size_result->variable); |                                 b.Expr(buffer_size_result->variable); | ||||||
| 
 | 
 | ||||||
|                             const sem::Array* array_type = Switch( |                             const sem::Array* array_type = Switch( | ||||||
|                                 storage_buffer_type->StoreType(), |                                 storage_buffer_type->StoreType(), | ||||||
| @ -205,23 +208,21 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|                                     // The variable is a struct, so subtract the byte offset of
 |                                     // The variable is a struct, so subtract the byte offset of
 | ||||||
|                                     // the array member.
 |                                     // the array member.
 | ||||||
|                                     auto* array_member_sem = str->Members().back(); |                                     auto* array_member_sem = str->Members().back(); | ||||||
|                                     total_size = |                                     total_size = b.Sub(total_size, u32(array_member_sem->Offset())); | ||||||
|                                         ctx.dst->Sub(total_size, u32(array_member_sem->Offset())); |  | ||||||
|                                     return array_member_sem->Type()->As<sem::Array>(); |                                     return array_member_sem->Type()->As<sem::Array>(); | ||||||
|                                 }, |                                 }, | ||||||
|                                 [&](const sem::Array* arr) { return arr; }); |                                 [&](const sem::Array* arr) { return arr; }); | ||||||
| 
 | 
 | ||||||
|                             if (!array_type) { |                             if (!array_type) { | ||||||
|                                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                                     << "expected form of arrayLength argument to be " |                                     << "expected form of arrayLength argument to be " | ||||||
|                                        "&array_var or &struct_var.array_member"; |                                        "&array_var or &struct_var.array_member"; | ||||||
|                                 return name; |                                 return name; | ||||||
|                             } |                             } | ||||||
| 
 | 
 | ||||||
|                             uint32_t array_stride = array_type->Size(); |                             uint32_t array_stride = array_type->Size(); | ||||||
|                             auto* array_length_var = ctx.dst->Decl( |                             auto* array_length_var = b.Decl( | ||||||
|                                 ctx.dst->Let(name, ctx.dst->ty.u32(), |                                 b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride)))); | ||||||
|                                              ctx.dst->Div(total_size, u32(array_stride)))); |  | ||||||
| 
 | 
 | ||||||
|                             // Insert the array length calculations at the top of the block
 |                             // Insert the array length calculations at the top of the block
 | ||||||
|                             ctx.InsertBefore(block->statements, block->statements[0], |                             ctx.InsertBefore(block->statements, block->statements[0], | ||||||
| @ -234,13 +235,14 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|                         }); |                         }); | ||||||
| 
 | 
 | ||||||
|                     // Replace the call to arrayLength() with the array length variable
 |                     // Replace the call to arrayLength() with the array length variable
 | ||||||
|                     ctx.Replace(call_expr, ctx.dst->Expr(array_length)); |                     ctx.Replace(call_expr, b.Expr(array_length)); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -59,19 +59,10 @@ class CalculateArrayLength final : public Castable<CalculateArrayLength, Transfo | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~CalculateArrayLength() override; |     ~CalculateArrayLength() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -123,7 +123,7 @@ bool HasSampleMask(utils::VectorRef<const ast::Attribute*> attrs) { | |||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// State holds the current transform state for a single entry point.
 | /// PIMPL state for the transform
 | ||||||
| struct CanonicalizeEntryPointIO::State { | struct CanonicalizeEntryPointIO::State { | ||||||
|     /// OutputValue represents a shader result that the wrapper function produces.
 |     /// OutputValue represents a shader result that the wrapper function produces.
 | ||||||
|     struct OutputValue { |     struct OutputValue { | ||||||
| @ -770,17 +770,22 @@ struct CanonicalizeEntryPointIO::State { | |||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src, | ||||||
|  |                                                        const DataMap& inputs, | ||||||
|  |                                                        DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     auto* cfg = inputs.Get<Config>(); |     auto* cfg = inputs.Get<Config>(); | ||||||
|     if (cfg == nullptr) { |     if (cfg == nullptr) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||||
|         return; |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Remove entry point IO attributes from struct declarations.
 |     // Remove entry point IO attributes from struct declarations.
 | ||||||
|     // New structures will be created for each entry point, as necessary.
 |     // New structures will be created for each entry point, as necessary.
 | ||||||
|     for (auto* ty : ctx.src->AST().TypeDecls()) { |     for (auto* ty : src->AST().TypeDecls()) { | ||||||
|         if (auto* struct_ty = ty->As<ast::Struct>()) { |         if (auto* struct_ty = ty->As<ast::Struct>()) { | ||||||
|             for (auto* member : struct_ty->members) { |             for (auto* member : struct_ty->members) { | ||||||
|                 for (auto* attr : member->attributes) { |                 for (auto* attr : member->attributes) { | ||||||
| @ -792,7 +797,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     for (auto* func_ast : ctx.src->AST().Functions()) { |     for (auto* func_ast : src->AST().Functions()) { | ||||||
|         if (!func_ast->IsEntryPoint()) { |         if (!func_ast->IsEntryPoint()) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
| @ -802,6 +807,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| CanonicalizeEntryPointIO::Config::Config(ShaderStyle style, | CanonicalizeEntryPointIO::Config::Config(ShaderStyle style, | ||||||
|  | |||||||
| @ -127,15 +127,12 @@ class CanonicalizeEntryPointIO final : public Castable<CanonicalizeEntryPointIO, | |||||||
|     CanonicalizeEntryPointIO(); |     CanonicalizeEntryPointIO(); | ||||||
|     ~CanonicalizeEntryPointIO() override; |     ~CanonicalizeEntryPointIO() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|  |   private: | ||||||
|     struct State; |     struct State; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -14,7 +14,7 @@ | |||||||
| 
 | 
 | ||||||
| #include "src/tint/transform/clamp_frag_depth.h" | #include "src/tint/transform/clamp_frag_depth.h" | ||||||
| 
 | 
 | ||||||
|  #include <utility> | #include <utility> | ||||||
| 
 | 
 | ||||||
| #include "src/tint/ast/attribute.h" | #include "src/tint/ast/attribute.h" | ||||||
| #include "src/tint/ast/builtin_attribute.h" | #include "src/tint/ast/builtin_attribute.h" | ||||||
| @ -64,12 +64,7 @@ bool ReturnsFragDepthInStruct(const sem::Info& sem, const ast::Function* fn) { | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // anonymous namespace
 | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| ClampFragDepth::ClampFragDepth() = default; |  | ||||||
| ClampFragDepth::~ClampFragDepth() = default; |  | ||||||
| 
 |  | ||||||
| bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     auto& sem = program->Sem(); |     auto& sem = program->Sem(); | ||||||
| 
 | 
 | ||||||
|     for (auto* fn : program->AST().Functions()) { |     for (auto* fn : program->AST().Functions()) { | ||||||
| @ -82,22 +77,33 @@ bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // anonymous namespace
 | ||||||
|  | 
 | ||||||
|  | ClampFragDepth::ClampFragDepth() = default; | ||||||
|  | ClampFragDepth::~ClampFragDepth() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     // Abort on any use of push constants in the module.
 |     // Abort on any use of push constants in the module.
 | ||||||
|     for (auto* global : ctx.src->AST().GlobalVariables()) { |     for (auto* global : src->AST().GlobalVariables()) { | ||||||
|         if (auto* var = global->As<ast::Var>()) { |         if (auto* var = global->As<ast::Var>()) { | ||||||
|             if (var->declared_address_space == ast::AddressSpace::kPushConstant) { |             if (var->declared_address_space == ast::AddressSpace::kPushConstant) { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "ClampFragDepth doesn't know how to handle module that already use push " |                     << "ClampFragDepth doesn't know how to handle module that already use push " | ||||||
|                        "constants."; |                        "constants."; | ||||||
|                 return; |                 return Program(std::move(b)); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     auto& b = *ctx.dst; |     if (!ShouldRun(src)) { | ||||||
|     auto& sem = ctx.src->Sem(); |         return SkipTransform; | ||||||
|     auto& sym = ctx.src->Symbols(); |     } | ||||||
|  | 
 | ||||||
|  |     auto& sem = src->Sem(); | ||||||
|  |     auto& sym = src->Symbols(); | ||||||
| 
 | 
 | ||||||
|     // At least one entry-point needs clamping. Add the following to the module:
 |     // At least one entry-point needs clamping. Add the following to the module:
 | ||||||
|     //
 |     //
 | ||||||
| @ -197,6 +203,7 @@ void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -61,19 +61,10 @@ class ClampFragDepth final : public Castable<ClampFragDepth, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~ClampFragDepth() override; |     ~ClampFragDepth() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -47,10 +47,14 @@ CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map, | |||||||
| CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default; | CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default; | ||||||
| CombineSamplers::BindingInfo::~BindingInfo() = default; | CombineSamplers::BindingInfo::~BindingInfo() = default; | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the CombineSamplers transform
 | /// PIMPL state for the transform
 | ||||||
| struct CombineSamplers::State { | struct CombineSamplers::State { | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
| 
 | 
 | ||||||
|     /// The binding info
 |     /// The binding info
 | ||||||
|     const BindingInfo* binding_info; |     const BindingInfo* binding_info; | ||||||
| @ -88,9 +92,9 @@ struct CombineSamplers::State { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param context the clone context
 |     /// @param program the source program
 | ||||||
|     /// @param info the binding map information
 |     /// @param info the binding map information
 | ||||||
|     State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {} |     State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {} | ||||||
| 
 | 
 | ||||||
|     /// Creates a combined sampler global variables.
 |     /// Creates a combined sampler global variables.
 | ||||||
|     /// (Note this is actually a Texture node at the AST level, but it will be
 |     /// (Note this is actually a Texture node at the AST level, but it will be
 | ||||||
| @ -145,8 +149,9 @@ struct CombineSamplers::State { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Performs the transformation
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|         auto& sem = ctx.src->Sem(); |         auto& sem = ctx.src->Sem(); | ||||||
| 
 | 
 | ||||||
|         // Remove all texture and sampler global variables. These will be replaced
 |         // Remove all texture and sampler global variables. These will be replaced
 | ||||||
| @ -169,14 +174,14 @@ struct CombineSamplers::State { | |||||||
| 
 | 
 | ||||||
|         // Rewrite all function signatures to use combined samplers, and remove
 |         // Rewrite all function signatures to use combined samplers, and remove
 | ||||||
|         // separate textures & samplers. Create new combined globals where found.
 |         // separate textures & samplers. Create new combined globals where found.
 | ||||||
|         ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* { |         ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* { | ||||||
|             if (auto* func = sem.Get(src)) { |             if (auto* fn = sem.Get(ast_fn)) { | ||||||
|                 auto pairs = func->TextureSamplerPairs(); |                 auto pairs = fn->TextureSamplerPairs(); | ||||||
|                 if (pairs.IsEmpty()) { |                 if (pairs.IsEmpty()) { | ||||||
|                     return nullptr; |                     return nullptr; | ||||||
|                 } |                 } | ||||||
|                 utils::Vector<const ast::Parameter*, 8> params; |                 utils::Vector<const ast::Parameter*, 8> params; | ||||||
|                 for (auto pair : func->TextureSamplerPairs()) { |                 for (auto pair : fn->TextureSamplerPairs()) { | ||||||
|                     const sem::Variable* texture_var = pair.first; |                     const sem::Variable* texture_var = pair.first; | ||||||
|                     const sem::Variable* sampler_var = pair.second; |                     const sem::Variable* sampler_var = pair.second; | ||||||
|                     std::string name = |                     std::string name = | ||||||
| @ -197,23 +202,23 @@ struct CombineSamplers::State { | |||||||
|                         auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var); |                         auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var); | ||||||
|                         auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); |                         auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); | ||||||
|                         params.Push(var); |                         params.Push(var); | ||||||
|                         function_combined_texture_samplers_[func][pair] = var; |                         function_combined_texture_samplers_[fn][pair] = var; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 // Filter out separate textures and samplers from the original
 |                 // Filter out separate textures and samplers from the original
 | ||||||
|                 // function signature.
 |                 // function signature.
 | ||||||
|                 for (auto* var : src->params) { |                 for (auto* param : fn->Parameters()) { | ||||||
|                     if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) { |                     if (!param->Type()->IsAnyOf<sem::Texture, sem::Sampler>()) { | ||||||
|                         params.Push(ctx.Clone(var)); |                         params.Push(ctx.Clone(param->Declaration())); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 // Create a new function signature that differs only in the parameter
 |                 // Create a new function signature that differs only in the parameter
 | ||||||
|                 // list.
 |                 // list.
 | ||||||
|                 auto symbol = ctx.Clone(src->symbol); |                 auto symbol = ctx.Clone(ast_fn->symbol); | ||||||
|                 auto* return_type = ctx.Clone(src->return_type); |                 auto* return_type = ctx.Clone(ast_fn->return_type); | ||||||
|                 auto* body = ctx.Clone(src->body); |                 auto* body = ctx.Clone(ast_fn->body); | ||||||
|                 auto attributes = ctx.Clone(src->attributes); |                 auto attributes = ctx.Clone(ast_fn->attributes); | ||||||
|                 auto return_type_attributes = ctx.Clone(src->return_type_attributes); |                 auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes); | ||||||
|                 return ctx.dst->create<ast::Function>(symbol, params, return_type, body, |                 return ctx.dst->create<ast::Function>(symbol, params, return_type, body, | ||||||
|                                                       std::move(attributes), |                                                       std::move(attributes), | ||||||
|                                                       std::move(return_type_attributes)); |                                                       std::move(return_type_attributes)); | ||||||
| @ -327,6 +332,7 @@ struct CombineSamplers::State { | |||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| @ -334,15 +340,18 @@ CombineSamplers::CombineSamplers() = default; | |||||||
| 
 | 
 | ||||||
| CombineSamplers::~CombineSamplers() = default; | CombineSamplers::~CombineSamplers() = default; | ||||||
| 
 | 
 | ||||||
| void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult CombineSamplers::Apply(const Program* src, | ||||||
|  |                                               const DataMap& inputs, | ||||||
|  |                                               DataMap&) const { | ||||||
|     auto* binding_info = inputs.Get<BindingInfo>(); |     auto* binding_info = inputs.Get<BindingInfo>(); | ||||||
|     if (!binding_info) { |     if (!binding_info) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         ProgramBuilder b; | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|         return; |                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     State(ctx, binding_info).Run(); |     return State(src, binding_info).Run(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -88,17 +88,13 @@ class CombineSamplers final : public Castable<CombineSamplers, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~CombineSamplers() override; |     ~CombineSamplers() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// The PIMPL state for this transform
 |     ApplyResult Apply(const Program* program, | ||||||
|     struct State; |                       const DataMap& inputs, | ||||||
|  |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |   private: | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |     struct State; | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -47,6 +47,18 @@ namespace tint::transform { | |||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||||
|  |         if (auto* var = program->Sem().Get<sem::Variable>(decl)) { | ||||||
|  |             if (var->AddressSpace() == ast::AddressSpace::kStorage || | ||||||
|  |                 var->AddressSpace() == ast::AddressSpace::kUniform) { | ||||||
|  |                 return true; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// Offset is a simple ast::Expression builder interface, used to build byte
 | /// Offset is a simple ast::Expression builder interface, used to build byte
 | ||||||
| /// offsets for storage and uniform buffer accesses.
 | /// offsets for storage and uniform buffer accesses.
 | ||||||
| struct Offset : Castable<Offset> { | struct Offset : Castable<Offset> { | ||||||
| @ -291,7 +303,7 @@ struct Store { | |||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// State holds the current transform state
 | /// PIMPL state for the transform
 | ||||||
| struct DecomposeMemoryAccess::State { | struct DecomposeMemoryAccess::State { | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext& ctx; | ||||||
| @ -477,7 +489,7 @@ struct DecomposeMemoryAccess::State { | |||||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 |                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||||
|                         //   this method only handles storage and uniform.
 |                         //   this method only handles storage and uniform.
 | ||||||
|                         // * Runtime-sized arrays are not loadable.
 |                         // * Runtime-sized arrays are not loadable.
 | ||||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "unexpected non-constant array count"; |                             << "unexpected non-constant array count"; | ||||||
|                         arr_cnt = 1; |                         arr_cnt = 1; | ||||||
|                     } |                     } | ||||||
| @ -578,7 +590,7 @@ struct DecomposeMemoryAccess::State { | |||||||
|                                 // * Override-expression counts can only be applied to workgroup
 |                                 // * Override-expression counts can only be applied to workgroup
 | ||||||
|                                 //   arrays, and this method only handles storage and uniform.
 |                                 //   arrays, and this method only handles storage and uniform.
 | ||||||
|                                 // * Runtime-sized arrays are not storable.
 |                                 // * Runtime-sized arrays are not storable.
 | ||||||
|                                 TINT_ICE(Transform, ctx.dst->Diagnostics()) |                                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                                     << "unexpected non-constant array count"; |                                     << "unexpected non-constant array count"; | ||||||
|                                 arr_cnt = 1; |                                 arr_cnt = 1; | ||||||
|                             } |                             } | ||||||
| @ -808,21 +820,16 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const { | |||||||
| DecomposeMemoryAccess::DecomposeMemoryAccess() = default; | DecomposeMemoryAccess::DecomposeMemoryAccess() = default; | ||||||
| DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; | DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; | ||||||
| 
 | 
 | ||||||
| bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, | ||||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { |                                                     const DataMap&, | ||||||
|         if (auto* var = program->Sem().Get<sem::Variable>(decl)) { |                                                     DataMap&) const { | ||||||
|             if (var->AddressSpace() == ast::AddressSpace::kStorage || |     if (!ShouldRun(src)) { | ||||||
|                 var->AddressSpace() == ast::AddressSpace::kUniform) { |         return SkipTransform; | ||||||
|                 return true; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     auto& sem = ctx.src->Sem(); |  | ||||||
| 
 | 
 | ||||||
|  |     auto& sem = src->Sem(); | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|     State state(ctx); |     State state(ctx); | ||||||
| 
 | 
 | ||||||
|     // Scan the AST nodes for storage and uniform buffer accesses. Complex
 |     // Scan the AST nodes for storage and uniform buffer accesses. Complex
 | ||||||
| @ -833,7 +840,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con | |||||||
|     // Inner-most expression nodes are guaranteed to be visited first because AST
 |     // Inner-most expression nodes are guaranteed to be visited first because AST
 | ||||||
|     // nodes are fully immutable and require their children to be constructed
 |     // nodes are fully immutable and require their children to be constructed
 | ||||||
|     // first so their pointer can be passed to the parent's initializer.
 |     // first so their pointer can be passed to the parent's initializer.
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* ident = node->As<ast::IdentifierExpression>()) { |         if (auto* ident = node->As<ast::IdentifierExpression>()) { | ||||||
|             // X
 |             // X
 | ||||||
|             if (auto* var = sem.Get<sem::VariableUser>(ident)) { |             if (auto* var = sem.Get<sem::VariableUser>(ident)) { | ||||||
| @ -1001,6 +1008,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -108,20 +108,12 @@ class DecomposeMemoryAccess final : public Castable<DecomposeMemoryAccess, Trans | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~DecomposeMemoryAccess() override; |     ~DecomposeMemoryAccess() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|  |   private: | ||||||
|     struct State; |     struct State; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,13 +34,7 @@ namespace { | |||||||
| 
 | 
 | ||||||
| using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>; | using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>; | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| DecomposeStridedArray::DecomposeStridedArray() = default; |  | ||||||
| 
 |  | ||||||
| DecomposeStridedArray::~DecomposeStridedArray() = default; |  | ||||||
| 
 |  | ||||||
| bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (auto* ast = node->As<ast::Array>()) { |         if (auto* ast = node->As<ast::Array>()) { | ||||||
|             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { |             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { | ||||||
| @ -51,8 +45,22 @@ bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) co | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|     const auto& sem = ctx.src->Sem(); | 
 | ||||||
|  | DecomposeStridedArray::DecomposeStridedArray() = default; | ||||||
|  | 
 | ||||||
|  | DecomposeStridedArray::~DecomposeStridedArray() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src, | ||||||
|  |                                                     const DataMap&, | ||||||
|  |                                                     DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     const auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|     static constexpr const char* kMemberName = "el"; |     static constexpr const char* kMemberName = "el"; | ||||||
| 
 | 
 | ||||||
| @ -69,23 +77,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | |||||||
|         if (auto* arr = sem.Get(ast)) { |         if (auto* arr = sem.Get(ast)) { | ||||||
|             if (!arr->IsStrideImplicit()) { |             if (!arr->IsStrideImplicit()) { | ||||||
|                 auto el_ty = utils::GetOrCreate(decomposed, arr, [&] { |                 auto el_ty = utils::GetOrCreate(decomposed, arr, [&] { | ||||||
|                     auto name = ctx.dst->Symbols().New("strided_arr"); |                     auto name = b.Symbols().New("strided_arr"); | ||||||
|                     auto* member_ty = ctx.Clone(ast->type); |                     auto* member_ty = ctx.Clone(ast->type); | ||||||
|                     auto* member = ctx.dst->Member(kMemberName, member_ty, |                     auto* member = b.Member(kMemberName, member_ty, | ||||||
|                                                    utils::Vector{ |                                             utils::Vector{ | ||||||
|                                                        ctx.dst->MemberSize(AInt(arr->Stride())), |                                                 b.MemberSize(AInt(arr->Stride())), | ||||||
|                                                    }); |                                             }); | ||||||
|                     ctx.dst->Structure(name, utils::Vector{member}); |                     b.Structure(name, utils::Vector{member}); | ||||||
|                     return name; |                     return name; | ||||||
|                 }); |                 }); | ||||||
|                 auto* count = ctx.Clone(ast->count); |                 auto* count = ctx.Clone(ast->count); | ||||||
|                 return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count); |                 return b.ty.array(b.ty.type_name(el_ty), count); | ||||||
|             } |             } | ||||||
|             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { |             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { | ||||||
|                 // Strip the @stride attribute
 |                 // Strip the @stride attribute
 | ||||||
|                 auto* ty = ctx.Clone(ast->type); |                 auto* ty = ctx.Clone(ast->type); | ||||||
|                 auto* count = ctx.Clone(ast->count); |                 auto* count = ctx.Clone(ast->count); | ||||||
|                 return ctx.dst->ty.array(ty, count); |                 return b.ty.array(ty, count); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         return nullptr; |         return nullptr; | ||||||
| @ -96,11 +104,11 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | |||||||
|     // to insert an additional member accessor for the single structure field.
 |     // to insert an additional member accessor for the single structure field.
 | ||||||
|     // Example: `arr[i]` -> `arr[i].el`
 |     // Example: `arr[i]` -> `arr[i].el`
 | ||||||
|     ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* { |     ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* { | ||||||
|         if (auto* ty = ctx.src->TypeOf(idx->object)) { |         if (auto* ty = src->TypeOf(idx->object)) { | ||||||
|             if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) { |             if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) { | ||||||
|                 if (!arr->IsStrideImplicit()) { |                 if (!arr->IsStrideImplicit()) { | ||||||
|                     auto* expr = ctx.CloneWithoutTransform(idx); |                     auto* expr = ctx.CloneWithoutTransform(idx); | ||||||
|                     return ctx.dst->MemberAccessor(expr, kMemberName); |                     return b.MemberAccessor(expr, kMemberName); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @ -136,21 +144,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | |||||||
|                         if (auto it = decomposed.find(arr); it != decomposed.end()) { |                         if (auto it = decomposed.find(arr); it != decomposed.end()) { | ||||||
|                             args.Reserve(expr->args.Length()); |                             args.Reserve(expr->args.Length()); | ||||||
|                             for (auto* arg : expr->args) { |                             for (auto* arg : expr->args) { | ||||||
|                                 args.Push(ctx.dst->Call(it->second, ctx.Clone(arg))); |                                 args.Push(b.Call(it->second, ctx.Clone(arg))); | ||||||
|                             } |                             } | ||||||
|                         } else { |                         } else { | ||||||
|                             args = ctx.Clone(expr->args); |                             args = ctx.Clone(expr->args); | ||||||
|                         } |                         } | ||||||
| 
 | 
 | ||||||
|                         return target.type ? ctx.dst->Construct(target.type, std::move(args)) |                         return target.type ? b.Construct(target.type, std::move(args)) | ||||||
|                                            : ctx.dst->Call(target.name, std::move(args)); |                                            : b.Call(target.name, std::move(args)); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     }); |     }); | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -35,19 +35,10 @@ class DecomposeStridedArray final : public Castable<DecomposeStridedArray, Trans | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~DecomposeStridedArray() override; |     ~DecomposeStridedArray() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -53,24 +53,25 @@ struct MatrixInfo { | |||||||
|     }; |     }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /// Return type of the callback function of GatherCustomStrideMatrixMembers
 | }  // namespace
 | ||||||
| enum GatherResult { kContinue, kStop }; |  | ||||||
| 
 | 
 | ||||||
| /// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
 | DecomposeStridedMatrix::DecomposeStridedMatrix() = default; | ||||||
| /// storage and uniform structs, which are of a matrix type, and have a custom
 | 
 | ||||||
| /// matrix stride attribute. For each matrix member found, `callback` is called.
 | DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; | ||||||
| /// `callback` is a function with the signature:
 | 
 | ||||||
| ///      GatherResult(const sem::StructMember* member,
 | Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, | ||||||
| ///                   sem::Matrix* matrix,
 |                                                      const DataMap&, | ||||||
| ///                   uint32_t stride)
 |                                                      DataMap&) const { | ||||||
| /// If `callback` return GatherResult::kStop, then the scanning will immediately
 |     ProgramBuilder b; | ||||||
| /// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
 |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
| /// scanning will continue.
 | 
 | ||||||
| template <typename F> |     // Scan the program for all storage and uniform structure matrix members with
 | ||||||
| void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { |     // a custom stride attribute. Replace these matrices with an equivalent array,
 | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     // and populate the `decomposed` map with the members that have been replaced.
 | ||||||
|  |     utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed; | ||||||
|  |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* str = node->As<ast::Struct>()) { |         if (auto* str = node->As<ast::Struct>()) { | ||||||
|             auto* str_ty = program->Sem().Get(str); |             auto* str_ty = src->Sem().Get(str); | ||||||
|             if (!str_ty->UsedAs(ast::AddressSpace::kUniform) && |             if (!str_ty->UsedAs(ast::AddressSpace::kUniform) && | ||||||
|                 !str_ty->UsedAs(ast::AddressSpace::kStorage)) { |                 !str_ty->UsedAs(ast::AddressSpace::kStorage)) { | ||||||
|                 continue; |                 continue; | ||||||
| @ -89,46 +90,20 @@ void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { | |||||||
|                 if (matrix->ColumnStride() == stride) { |                 if (matrix->ColumnStride() == stride) { | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
|                 if (callback(member, matrix, stride) == GatherResult::kStop) { |                 // We've got ourselves a struct member of a matrix type with a custom
 | ||||||
|                     return; |                 // stride. Replace this with an array of column vectors.
 | ||||||
|                 } |                 MatrixInfo info{stride, matrix}; | ||||||
|  |                 auto* replacement = | ||||||
|  |                     b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); | ||||||
|  |                 ctx.Replace(member->Declaration(), replacement); | ||||||
|  |                 decomposed.Add(member->Declaration(), info); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| }  // namespace
 |     if (decomposed.IsEmpty()) { | ||||||
| 
 |         return SkipTransform; | ||||||
| DecomposeStridedMatrix::DecomposeStridedMatrix() = default; |     } | ||||||
| 
 |  | ||||||
| DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; |  | ||||||
| 
 |  | ||||||
| bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     bool should_run = false; |  | ||||||
|     GatherCustomStrideMatrixMembers(program, |  | ||||||
|                                     [&](const sem::StructMember*, const sem::Matrix*, uint32_t) { |  | ||||||
|                                         should_run = true; |  | ||||||
|                                         return GatherResult::kStop; |  | ||||||
|                                     }); |  | ||||||
|     return should_run; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     // Scan the program for all storage and uniform structure matrix members with
 |  | ||||||
|     // a custom stride attribute. Replace these matrices with an equivalent array,
 |  | ||||||
|     // and populate the `decomposed` map with the members that have been replaced.
 |  | ||||||
|     std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed; |  | ||||||
|     GatherCustomStrideMatrixMembers( |  | ||||||
|         ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) { |  | ||||||
|             // We've got ourselves a struct member of a matrix type with a custom
 |  | ||||||
|             // stride. Replace this with an array of column vectors.
 |  | ||||||
|             MatrixInfo info{stride, matrix}; |  | ||||||
|             auto* replacement = |  | ||||||
|                 ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); |  | ||||||
|             ctx.Replace(member->Declaration(), replacement); |  | ||||||
|             decomposed.emplace(member->Declaration(), info); |  | ||||||
|             return GatherResult::kContinue; |  | ||||||
|         }); |  | ||||||
| 
 | 
 | ||||||
|     // For all expressions where a single matrix column vector was indexed, we can
 |     // For all expressions where a single matrix column vector was indexed, we can
 | ||||||
|     // preserve these without calling conversion functions.
 |     // preserve these without calling conversion functions.
 | ||||||
| @ -136,12 +111,11 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | |||||||
|     //   ssbo.mat[2] -> ssbo.mat[2]
 |     //   ssbo.mat[2] -> ssbo.mat[2]
 | ||||||
|     ctx.ReplaceAll( |     ctx.ReplaceAll( | ||||||
|         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { |         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { | ||||||
|             if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) { |             if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) { | ||||||
|                 auto it = decomposed.find(access->Member()->Declaration()); |                 if (decomposed.Contains(access->Member()->Declaration())) { | ||||||
|                 if (it != decomposed.end()) { |  | ||||||
|                     auto* obj = ctx.CloneWithoutTransform(expr->object); |                     auto* obj = ctx.CloneWithoutTransform(expr->object); | ||||||
|                     auto* idx = ctx.Clone(expr->index); |                     auto* idx = ctx.Clone(expr->index); | ||||||
|                     return ctx.dst->IndexAccessor(obj, idx); |                     return b.IndexAccessor(obj, idx); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             return nullptr; |             return nullptr; | ||||||
| @ -154,39 +128,36 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | |||||||
|     //   ssbo.mat = mat_to_arr(m)
 |     //   ssbo.mat = mat_to_arr(m)
 | ||||||
|     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr; |     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr; | ||||||
|     ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { |     ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { | ||||||
|         if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { |         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { | ||||||
|             auto it = decomposed.find(access->Member()->Declaration()); |             if (auto* info = decomposed.Find(access->Member()->Declaration())) { | ||||||
|             if (it == decomposed.end()) { |                 auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { | ||||||
|                 return nullptr; |                     auto name = | ||||||
|  |                         b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" + | ||||||
|  |                                         std::to_string(info->matrix->rows()) + "_stride_" + | ||||||
|  |                                         std::to_string(info->stride) + "_to_arr"); | ||||||
|  | 
 | ||||||
|  |                     auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; | ||||||
|  |                     auto array = [&] { return info->array(ctx.dst); }; | ||||||
|  | 
 | ||||||
|  |                     auto mat = b.Sym("m"); | ||||||
|  |                     utils::Vector<const ast::Expression*, 4> columns; | ||||||
|  |                     for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { | ||||||
|  |                         columns.Push(b.IndexAccessor(mat, u32(i))); | ||||||
|  |                     } | ||||||
|  |                     b.Func(name, | ||||||
|  |                            utils::Vector{ | ||||||
|  |                                b.Param(mat, matrix()), | ||||||
|  |                            }, | ||||||
|  |                            array(), | ||||||
|  |                            utils::Vector{ | ||||||
|  |                                b.Return(b.Construct(array(), columns)), | ||||||
|  |                            }); | ||||||
|  |                     return name; | ||||||
|  |                 }); | ||||||
|  |                 auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); | ||||||
|  |                 auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs)); | ||||||
|  |                 return b.Assign(lhs, rhs); | ||||||
|             } |             } | ||||||
|             MatrixInfo info = it->second; |  | ||||||
|             auto fn = utils::GetOrCreate(mat_to_arr, info, [&] { |  | ||||||
|                 auto name = |  | ||||||
|                     ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" + |  | ||||||
|                                            std::to_string(info.matrix->rows()) + "_stride_" + |  | ||||||
|                                            std::to_string(info.stride) + "_to_arr"); |  | ||||||
| 
 |  | ||||||
|                 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; |  | ||||||
|                 auto array = [&] { return info.array(ctx.dst); }; |  | ||||||
| 
 |  | ||||||
|                 auto mat = ctx.dst->Sym("m"); |  | ||||||
|                 utils::Vector<const ast::Expression*, 4> columns; |  | ||||||
|                 for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) { |  | ||||||
|                     columns.Push(ctx.dst->IndexAccessor(mat, u32(i))); |  | ||||||
|                 } |  | ||||||
|                 ctx.dst->Func(name, |  | ||||||
|                               utils::Vector{ |  | ||||||
|                                   ctx.dst->Param(mat, matrix()), |  | ||||||
|                               }, |  | ||||||
|                               array(), |  | ||||||
|                               utils::Vector{ |  | ||||||
|                                   ctx.dst->Return(ctx.dst->Construct(array(), columns)), |  | ||||||
|                               }); |  | ||||||
|                 return name; |  | ||||||
|             }); |  | ||||||
|             auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); |  | ||||||
|             auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs)); |  | ||||||
|             return ctx.dst->Assign(lhs, rhs); |  | ||||||
|         } |         } | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     }); |     }); | ||||||
| @ -196,41 +167,40 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | |||||||
|     //   m = arr_to_mat(ssbo.mat)
 |     //   m = arr_to_mat(ssbo.mat)
 | ||||||
|     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat; |     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat; | ||||||
|     ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { |     ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { | ||||||
|         if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) { |         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) { | ||||||
|             auto it = decomposed.find(access->Member()->Declaration()); |             if (auto* info = decomposed.Find(access->Member()->Declaration())) { | ||||||
|             if (it == decomposed.end()) { |                 auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { | ||||||
|                 return nullptr; |                     auto name = | ||||||
|  |                         b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) + | ||||||
|  |                                         "x" + std::to_string(info->matrix->rows()) + "_stride_" + | ||||||
|  |                                         std::to_string(info->stride)); | ||||||
|  | 
 | ||||||
|  |                     auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; | ||||||
|  |                     auto array = [&] { return info->array(ctx.dst); }; | ||||||
|  | 
 | ||||||
|  |                     auto arr = b.Sym("arr"); | ||||||
|  |                     utils::Vector<const ast::Expression*, 4> columns; | ||||||
|  |                     for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { | ||||||
|  |                         columns.Push(b.IndexAccessor(arr, u32(i))); | ||||||
|  |                     } | ||||||
|  |                     b.Func(name, | ||||||
|  |                            utils::Vector{ | ||||||
|  |                                b.Param(arr, array()), | ||||||
|  |                            }, | ||||||
|  |                            matrix(), | ||||||
|  |                            utils::Vector{ | ||||||
|  |                                b.Return(b.Construct(matrix(), columns)), | ||||||
|  |                            }); | ||||||
|  |                     return name; | ||||||
|  |                 }); | ||||||
|  |                 return b.Call(fn, ctx.CloneWithoutTransform(expr)); | ||||||
|             } |             } | ||||||
|             MatrixInfo info = it->second; |  | ||||||
|             auto fn = utils::GetOrCreate(arr_to_mat, info, [&] { |  | ||||||
|                 auto name = ctx.dst->Symbols().New( |  | ||||||
|                     "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" + |  | ||||||
|                     std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride)); |  | ||||||
| 
 |  | ||||||
|                 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; |  | ||||||
|                 auto array = [&] { return info.array(ctx.dst); }; |  | ||||||
| 
 |  | ||||||
|                 auto arr = ctx.dst->Sym("arr"); |  | ||||||
|                 utils::Vector<const ast::Expression*, 4> columns; |  | ||||||
|                 for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) { |  | ||||||
|                     columns.Push(ctx.dst->IndexAccessor(arr, u32(i))); |  | ||||||
|                 } |  | ||||||
|                 ctx.dst->Func(name, |  | ||||||
|                               utils::Vector{ |  | ||||||
|                                   ctx.dst->Param(arr, array()), |  | ||||||
|                               }, |  | ||||||
|                               matrix(), |  | ||||||
|                               utils::Vector{ |  | ||||||
|                                   ctx.dst->Return(ctx.dst->Construct(matrix(), columns)), |  | ||||||
|                               }); |  | ||||||
|                 return name; |  | ||||||
|             }); |  | ||||||
|             return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr)); |  | ||||||
|         } |         } | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -35,19 +35,10 @@ class DecomposeStridedMatrix final : public Castable<DecomposeStridedMatrix, Tra | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~DecomposeStridedMatrix() override; |     ~DecomposeStridedMatrix() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -27,14 +27,20 @@ DisableUniformityAnalysis::DisableUniformityAnalysis() = default; | |||||||
| 
 | 
 | ||||||
| DisableUniformityAnalysis::~DisableUniformityAnalysis() = default; | DisableUniformityAnalysis::~DisableUniformityAnalysis() = default; | ||||||
| 
 | 
 | ||||||
| bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src, | ||||||
|     return !program->Sem().Module()->Extensions().Contains( |                                                         const DataMap&, | ||||||
|         ast::Extension::kChromiumDisableUniformityAnalysis); |                                                         DataMap&) const { | ||||||
| } |     if (src->Sem().Module()->Extensions().Contains( | ||||||
|  |             ast::Extension::kChromiumDisableUniformityAnalysis)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis); | ||||||
| 
 | 
 | ||||||
| void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis); |  | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -27,19 +27,10 @@ class DisableUniformityAnalysis final : public Castable<DisableUniformityAnalysi | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~DisableUniformityAnalysis() override; |     ~DisableUniformityAnalysis() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,11 +31,9 @@ using namespace tint::number_suffixes;  // NOLINT | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| ExpandCompoundAssignment::ExpandCompoundAssignment() = default; | namespace { | ||||||
| 
 | 
 | ||||||
| ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) { |         if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) { | ||||||
|             return true; |             return true; | ||||||
| @ -44,21 +42,10 @@ bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| namespace { | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// Internal class used to collect statement expansions during the transform.
 | /// PIMPL state for the transform
 | ||||||
| class State { | struct ExpandCompoundAssignment::State { | ||||||
|   private: |  | ||||||
|     /// The clone context.
 |  | ||||||
|     CloneContext& ctx; |  | ||||||
| 
 |  | ||||||
|     /// The program builder.
 |  | ||||||
|     ProgramBuilder& b; |  | ||||||
| 
 |  | ||||||
|     /// The HoistToDeclBefore helper instance.
 |  | ||||||
|     HoistToDeclBefore hoist_to_decl_before; |  | ||||||
| 
 |  | ||||||
|   public: |  | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param context the clone context
 |     /// @param context the clone context
 | ||||||
|     explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {} |     explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {} | ||||||
| @ -158,15 +145,32 @@ class State { | |||||||
|         ctx.Replace(stmt, b.Assign(new_lhs(), value)); |         ctx.Replace(stmt, b.Assign(new_lhs(), value)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Finalize the transformation and clone the module.
 |   private: | ||||||
|     void Finalize() { ctx.Clone(); } |     /// The clone context.
 | ||||||
|  |     CloneContext& ctx; | ||||||
|  | 
 | ||||||
|  |     /// The program builder.
 | ||||||
|  |     ProgramBuilder& b; | ||||||
|  | 
 | ||||||
|  |     /// The HoistToDeclBefore helper instance.
 | ||||||
|  |     HoistToDeclBefore hoist_to_decl_before; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | ExpandCompoundAssignment::ExpandCompoundAssignment() = default; | ||||||
| 
 | 
 | ||||||
| void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src, | ||||||
|  |                                                        const DataMap&, | ||||||
|  |                                                        DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|     State state(ctx); |     State state(ctx); | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) { |         if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) { | ||||||
|             state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op); |             state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op); | ||||||
|         } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) { |         } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) { | ||||||
| @ -175,7 +179,9 @@ void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) | |||||||
|             state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op); |             state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     state.Finalize(); | 
 | ||||||
|  |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -45,19 +45,13 @@ class ExpandCompoundAssignment final : public Castable<ExpandCompoundAssignment, | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~ExpandCompoundAssignment() override; |     ~ExpandCompoundAssignment() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   protected: |   private: | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     struct State; | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -35,6 +35,15 @@ namespace { | |||||||
| constexpr char kFirstVertexName[] = "first_vertex_index"; | constexpr char kFirstVertexName[] = "first_vertex_index"; | ||||||
| constexpr char kFirstInstanceName[] = "first_instance_index"; | constexpr char kFirstInstanceName[] = "first_instance_index"; | ||||||
| 
 | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* fn : program->AST().Functions()) { | ||||||
|  |         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| FirstIndexOffset::BindingPoint::BindingPoint() = default; | FirstIndexOffset::BindingPoint::BindingPoint() = default; | ||||||
| @ -49,16 +58,16 @@ FirstIndexOffset::Data::~Data() = default; | |||||||
| FirstIndexOffset::FirstIndexOffset() = default; | FirstIndexOffset::FirstIndexOffset() = default; | ||||||
| FirstIndexOffset::~FirstIndexOffset() = default; | FirstIndexOffset::~FirstIndexOffset() = default; | ||||||
| 
 | 
 | ||||||
| bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult FirstIndexOffset::Apply(const Program* src, | ||||||
|     for (auto* fn : program->AST().Functions()) { |                                                const DataMap& inputs, | ||||||
|         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { |                                                DataMap& outputs) const { | ||||||
|             return true; |     if (!ShouldRun(src)) { | ||||||
|         } |         return SkipTransform; | ||||||
|     } |     } | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     // Get the uniform buffer binding point
 |     // Get the uniform buffer binding point
 | ||||||
|     uint32_t ub_binding = binding_; |     uint32_t ub_binding = binding_; | ||||||
|     uint32_t ub_group = group_; |     uint32_t ub_group = group_; | ||||||
| @ -115,17 +124,17 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou | |||||||
|     if (has_vertex_or_instance_index) { |     if (has_vertex_or_instance_index) { | ||||||
|         // Add uniform buffer members and calculate byte offsets
 |         // Add uniform buffer members and calculate byte offsets
 | ||||||
|         utils::Vector<const ast::StructMember*, 8> members; |         utils::Vector<const ast::StructMember*, 8> members; | ||||||
|         members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32())); |         members.Push(b.Member(kFirstVertexName, b.ty.u32())); | ||||||
|         members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32())); |         members.Push(b.Member(kFirstInstanceName, b.ty.u32())); | ||||||
|         auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members)); |         auto* struct_ = b.Structure(b.Sym(), std::move(members)); | ||||||
| 
 | 
 | ||||||
|         // Create a global to hold the uniform buffer
 |         // Create a global to hold the uniform buffer
 | ||||||
|         Symbol buffer_name = ctx.dst->Sym(); |         Symbol buffer_name = b.Sym(); | ||||||
|         ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform, |         b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform, | ||||||
|                            utils::Vector{ |                     utils::Vector{ | ||||||
|                                ctx.dst->Binding(AInt(ub_binding)), |                         b.Binding(AInt(ub_binding)), | ||||||
|                                ctx.dst->Group(AInt(ub_group)), |                         b.Group(AInt(ub_group)), | ||||||
|                            }); |                     }); | ||||||
| 
 | 
 | ||||||
|         // Fix up all references to the builtins with the offsets
 |         // Fix up all references to the builtins with the offsets
 | ||||||
|         ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { |         ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { | ||||||
| @ -150,9 +159,10 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou | |||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |  | ||||||
| 
 |  | ||||||
|     outputs.Add<Data>(has_vertex_or_instance_index); |     outputs.Add<Data>(has_vertex_or_instance_index); | ||||||
|  | 
 | ||||||
|  |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -103,19 +103,10 @@ class FirstIndexOffset final : public Castable<FirstIndexOffset, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~FirstIndexOffset() override; |     ~FirstIndexOffset() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     uint32_t binding_ = 0; |     uint32_t binding_ = 0; | ||||||
|  | |||||||
| @ -14,17 +14,17 @@ | |||||||
| 
 | 
 | ||||||
| #include "src/tint/transform/for_loop_to_loop.h" | #include "src/tint/transform/for_loop_to_loop.h" | ||||||
| 
 | 
 | ||||||
|  | #include <utility> | ||||||
|  | 
 | ||||||
| #include "src/tint/ast/break_statement.h" | #include "src/tint/ast/break_statement.h" | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| 
 | 
 | ||||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop); | TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop); | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| ForLoopToLoop::ForLoopToLoop() = default; | namespace { | ||||||
| 
 | 
 | ||||||
| ForLoopToLoop::~ForLoopToLoop() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (node->Is<ast::ForLoopStatement>()) { |         if (node->Is<ast::ForLoopStatement>()) { | ||||||
|             return true; |             return true; | ||||||
| @ -33,19 +33,31 @@ bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | ForLoopToLoop::ForLoopToLoop() = default; | ||||||
|  | 
 | ||||||
|  | ForLoopToLoop::~ForLoopToLoop() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { |     ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { | ||||||
|         utils::Vector<const ast::Statement*, 8> stmts; |         utils::Vector<const ast::Statement*, 8> stmts; | ||||||
|         if (auto* cond = for_loop->condition) { |         if (auto* cond = for_loop->condition) { | ||||||
|             // !condition
 |             // !condition
 | ||||||
|             auto* not_cond = |             auto* not_cond = b.Not(ctx.Clone(cond)); | ||||||
|                 ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond)); |  | ||||||
| 
 | 
 | ||||||
|             // { break; }
 |             // { break; }
 | ||||||
|             auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>()); |             auto* break_body = b.Block(b.Break()); | ||||||
| 
 | 
 | ||||||
|             // if (!condition) { break; }
 |             // if (!condition) { break; }
 | ||||||
|             stmts.Push(ctx.dst->If(not_cond, break_body)); |             stmts.Push(b.If(not_cond, break_body)); | ||||||
|         } |         } | ||||||
|         for (auto* stmt : for_loop->body->statements) { |         for (auto* stmt : for_loop->body->statements) { | ||||||
|             stmts.Push(ctx.Clone(stmt)); |             stmts.Push(ctx.Clone(stmt)); | ||||||
| @ -53,20 +65,21 @@ void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
| 
 | 
 | ||||||
|         const ast::BlockStatement* continuing = nullptr; |         const ast::BlockStatement* continuing = nullptr; | ||||||
|         if (auto* cont = for_loop->continuing) { |         if (auto* cont = for_loop->continuing) { | ||||||
|             continuing = ctx.dst->Block(ctx.Clone(cont)); |             continuing = b.Block(ctx.Clone(cont)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         auto* body = ctx.dst->Block(stmts); |         auto* body = b.Block(stmts); | ||||||
|         auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing); |         auto* loop = b.Loop(body, continuing); | ||||||
| 
 | 
 | ||||||
|         if (auto* init = for_loop->initializer) { |         if (auto* init = for_loop->initializer) { | ||||||
|             return ctx.dst->Block(ctx.Clone(init), loop); |             return b.Block(ctx.Clone(init), loop); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return loop; |         return loop; | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -29,19 +29,10 @@ class ForLoopToLoop final : public Castable<ForLoopToLoop, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~ForLoopToLoop() override; |     ~ForLoopToLoop() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -32,70 +32,15 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// Private implementation of LocalizeStructArrayAssignment transform
 | /// PIMPL state for the transform
 | ||||||
| class LocalizeStructArrayAssignment::State { | struct LocalizeStructArrayAssignment::State { | ||||||
|   private: |  | ||||||
|     CloneContext& ctx; |  | ||||||
|     ProgramBuilder& b; |  | ||||||
| 
 |  | ||||||
|     /// Returns true if `expr` contains an index accessor expression to a
 |  | ||||||
|     /// structure member of array type.
 |  | ||||||
|     bool ContainsStructArrayIndex(const ast::Expression* expr) { |  | ||||||
|         bool result = false; |  | ||||||
|         ast::TraverseExpressions( |  | ||||||
|             expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { |  | ||||||
|                 // Indexing using a runtime value?
 |  | ||||||
|                 auto* idx_sem = ctx.src->Sem().Get(ia->index); |  | ||||||
|                 if (!idx_sem->ConstantValue()) { |  | ||||||
|                     // Indexing a member access expr?
 |  | ||||||
|                     if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { |  | ||||||
|                         // That accesses an array?
 |  | ||||||
|                         if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) { |  | ||||||
|                             result = true; |  | ||||||
|                             return ast::TraverseAction::Stop; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|                 return ast::TraverseAction::Descend; |  | ||||||
|             }); |  | ||||||
| 
 |  | ||||||
|         return result; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Returns the type and address space of the originating variable of the lhs
 |  | ||||||
|     // of the assignment statement.
 |  | ||||||
|     // See https://www.w3.org/TR/WGSL/#originating-variable-section
 |  | ||||||
|     std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace( |  | ||||||
|         const ast::AssignmentStatement* assign_stmt) { |  | ||||||
|         auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable(); |  | ||||||
|         if (!source_var) { |  | ||||||
|             TINT_ICE(Transform, b.Diagnostics()) |  | ||||||
|                 << "Unable to determine originating variable for lhs of assignment " |  | ||||||
|                    "statement"; |  | ||||||
|             return {}; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         auto* type = source_var->Type(); |  | ||||||
|         if (auto* ref = type->As<sem::Reference>()) { |  | ||||||
|             return {ref->StoreType(), ref->AddressSpace()}; |  | ||||||
|         } else if (auto* ptr = type->As<sem::Pointer>()) { |  | ||||||
|             return {ptr->StoreType(), ptr->AddressSpace()}; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         TINT_ICE(Transform, b.Diagnostics()) |  | ||||||
|             << "Expecting to find variable of type pointer or reference on lhs " |  | ||||||
|                "of assignment statement"; |  | ||||||
|         return {}; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|   public: |  | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param ctx_in the CloneContext primed with the input program and
 |     /// @param program the source program
 | ||||||
|     /// ProgramBuilder
 |     explicit State(const Program* program) : src(program) {} | ||||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} |  | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|         struct Shared { |         struct Shared { | ||||||
|             bool process_nested_nodes = false; |             bool process_nested_nodes = false; | ||||||
|             utils::Vector<const ast::Statement*, 4> insert_before_stmts; |             utils::Vector<const ast::Statement*, 4> insert_before_stmts; | ||||||
| @ -189,6 +134,65 @@ class LocalizeStructArrayAssignment::State { | |||||||
|             }); |             }); | ||||||
| 
 | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     /// The clone context
 | ||||||
|  |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     /// Returns true if `expr` contains an index accessor expression to a
 | ||||||
|  |     /// structure member of array type.
 | ||||||
|  |     bool ContainsStructArrayIndex(const ast::Expression* expr) { | ||||||
|  |         bool result = false; | ||||||
|  |         ast::TraverseExpressions( | ||||||
|  |             expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { | ||||||
|  |                 // Indexing using a runtime value?
 | ||||||
|  |                 auto* idx_sem = src->Sem().Get(ia->index); | ||||||
|  |                 if (!idx_sem->ConstantValue()) { | ||||||
|  |                     // Indexing a member access expr?
 | ||||||
|  |                     if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { | ||||||
|  |                         // That accesses an array?
 | ||||||
|  |                         if (src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) { | ||||||
|  |                             result = true; | ||||||
|  |                             return ast::TraverseAction::Stop; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 return ast::TraverseAction::Descend; | ||||||
|  |             }); | ||||||
|  | 
 | ||||||
|  |         return result; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Returns the type and address space of the originating variable of the lhs
 | ||||||
|  |     // of the assignment statement.
 | ||||||
|  |     // See https://www.w3.org/TR/WGSL/#originating-variable-section
 | ||||||
|  |     std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace( | ||||||
|  |         const ast::AssignmentStatement* assign_stmt) { | ||||||
|  |         auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable(); | ||||||
|  |         if (!source_var) { | ||||||
|  |             TINT_ICE(Transform, b.Diagnostics()) | ||||||
|  |                 << "Unable to determine originating variable for lhs of assignment " | ||||||
|  |                    "statement"; | ||||||
|  |             return {}; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         auto* type = source_var->Type(); | ||||||
|  |         if (auto* ref = type->As<sem::Reference>()) { | ||||||
|  |             return {ref->StoreType(), ref->AddressSpace()}; | ||||||
|  |         } else if (auto* ptr = type->As<sem::Pointer>()) { | ||||||
|  |             return {ptr->StoreType(), ptr->AddressSpace()}; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|  |             << "Expecting to find variable of type pointer or reference on lhs " | ||||||
|  |                "of assignment statement"; | ||||||
|  |         return {}; | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| @ -196,9 +200,10 @@ LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default; | |||||||
| 
 | 
 | ||||||
| LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default; | LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default; | ||||||
| 
 | 
 | ||||||
| void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src, | ||||||
|     State state(ctx); |                                                             const DataMap&, | ||||||
|     state.Run(); |                                                             DataMap&) const { | ||||||
|  |     return State{src}.Run(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -36,17 +36,13 @@ class LocalizeStructArrayAssignment final | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~LocalizeStructArrayAssignment() override; |     ~LocalizeStructArrayAssignment() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     class State; |     struct State; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,9 +31,9 @@ namespace tint::transform { | |||||||
| Manager::Manager() = default; | Manager::Manager() = default; | ||||||
| Manager::~Manager() = default; | Manager::~Manager() = default; | ||||||
| 
 | 
 | ||||||
| Output Manager::Run(const Program* program, const DataMap& data) const { | Transform::ApplyResult Manager::Apply(const Program* program, | ||||||
|     const Program* in = program; |                                       const DataMap& inputs, | ||||||
| 
 |                                       DataMap& outputs) const { | ||||||
| #if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM | #if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM | ||||||
|     auto print_program = [&](const char* msg, const Transform* transform) { |     auto print_program = [&](const char* msg, const Transform* transform) { | ||||||
|         auto wgsl = Program::printer(in); |         auto wgsl = Program::printer(in); | ||||||
| @ -46,34 +46,30 @@ Output Manager::Run(const Program* program, const DataMap& data) const { | |||||||
|     }; |     }; | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|     Output out; |     std::optional<Program> output; | ||||||
|  | 
 | ||||||
|     for (const auto& transform : transforms_) { |     for (const auto& transform : transforms_) { | ||||||
|         if (!transform->ShouldRun(in, data)) { |  | ||||||
|             TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name |  | ||||||
|                                             << std::endl); |  | ||||||
|             continue; |  | ||||||
|         } |  | ||||||
|         TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); |         TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); | ||||||
| 
 | 
 | ||||||
|         auto res = transform->Run(in, data); |         if (auto result = transform->Apply(program, inputs, outputs)) { | ||||||
|         out.program = std::move(res.program); |             output.emplace(std::move(result.value())); | ||||||
|         out.data.Add(std::move(res.data)); |             program = &output.value(); | ||||||
|         in = &out.program; |  | ||||||
|         if (!in->IsValid()) { |  | ||||||
|             TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); |  | ||||||
|             return out; |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         if (transform == transforms_.back()) { |             if (!program->IsValid()) { | ||||||
|             TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); |                 TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             if (transform == transforms_.back()) { | ||||||
|  |                 TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name | ||||||
|  |                                             << std::endl); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (program == in) { |     return output; | ||||||
|         out.program = program->Clone(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     return out; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -47,11 +47,10 @@ class Manager final : public Castable<Manager, Transform> { | |||||||
|         transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...)); |         transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Runs the transforms on `program`, returning the transformation result.
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param program the source program to transform
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @param data optional extra transform-specific input data
 |                       const DataMap& inputs, | ||||||
|     /// @returns the transformed program and diagnostics
 |                       DataMap& outputs) const override; | ||||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     std::vector<std::unique_ptr<Transform>> transforms_; |     std::vector<std::unique_ptr<Transform>> transforms_; | ||||||
|  | |||||||
| @ -65,15 +65,6 @@ MergeReturn::MergeReturn() = default; | |||||||
| 
 | 
 | ||||||
| MergeReturn::~MergeReturn() = default; | MergeReturn::~MergeReturn() = default; | ||||||
| 
 | 
 | ||||||
| bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* func : program->AST().Functions()) { |  | ||||||
|         if (NeedsTransform(program, func)) { |  | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| /// Internal class used to during the transform.
 | /// Internal class used to during the transform.
 | ||||||
| @ -223,7 +214,12 @@ class State { | |||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     bool made_changes = false; | ||||||
|  | 
 | ||||||
|     for (auto* func : ctx.src->AST().Functions()) { |     for (auto* func : ctx.src->AST().Functions()) { | ||||||
|         if (!NeedsTransform(ctx.src, func)) { |         if (!NeedsTransform(ctx.src, func)) { | ||||||
|             continue; |             continue; | ||||||
| @ -231,9 +227,15 @@ void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
| 
 | 
 | ||||||
|         State state(ctx, func); |         State state(ctx, func); | ||||||
|         state.ProcessStatement(func->body); |         state.ProcessStatement(func->body); | ||||||
|  |         made_changes = true; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (!made_changes) { | ||||||
|  |         return SkipTransform; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -27,19 +27,10 @@ class MergeReturn final : public Castable<MergeReturn, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~MergeReturn() override; |     ~MergeReturn() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -38,6 +38,15 @@ using WorkgroupParameterMemberList = utils::Vector<const ast::StructMember*, 8>; | |||||||
| // The name of the struct member for arrays that are wrapped in structures.
 | // The name of the struct member for arrays that are wrapped in structures.
 | ||||||
| const char* kWrappedArrayMemberName = "arr"; | const char* kWrappedArrayMemberName = "arr"; | ||||||
| 
 | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||||
|  |         if (decl->Is<ast::Variable>()) { | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Returns `true` if `type` is or contains a matrix type.
 | // Returns `true` if `type` is or contains a matrix type.
 | ||||||
| bool ContainsMatrix(const sem::Type* type) { | bool ContainsMatrix(const sem::Type* type) { | ||||||
|     type = type->UnwrapRef(); |     type = type->UnwrapRef(); | ||||||
| @ -56,7 +65,7 @@ bool ContainsMatrix(const sem::Type* type) { | |||||||
| } | } | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// State holds the current transform state.
 | /// PIMPL state for the transform
 | ||||||
| struct ModuleScopeVarToEntryPointParam::State { | struct ModuleScopeVarToEntryPointParam::State { | ||||||
|     /// The clone context.
 |     /// The clone context.
 | ||||||
|     CloneContext& ctx; |     CloneContext& ctx; | ||||||
| @ -501,19 +510,20 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default; | |||||||
| 
 | 
 | ||||||
| ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; | ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; | ||||||
| 
 | 
 | ||||||
| bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src, | ||||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { |                                                               const DataMap&, | ||||||
|         if (decl->Is<ast::Variable>()) { |                                                               DataMap&) const { | ||||||
|             return true; |     if (!ShouldRun(src)) { | ||||||
|         } |         return SkipTransform; | ||||||
|     } |     } | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|     State state{ctx}; |     State state{ctx}; | ||||||
|     state.Process(); |     state.Process(); | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -69,20 +69,12 @@ class ModuleScopeVarToEntryPointParam final | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~ModuleScopeVarToEntryPointParam() override; |     ~ModuleScopeVarToEntryPointParam() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|  |   private: | ||||||
|     struct State; |     struct State; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -31,6 +31,17 @@ using namespace tint::number_suffixes;  // NOLINT | |||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|  |         if (auto* ty = node->As<ast::Type>()) { | ||||||
|  |             if (program->Sem().Get<sem::ExternalTexture>(ty)) { | ||||||
|  |                 return true; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// This struct stores symbols for new bindings created as a result of transforming a
 | /// This struct stores symbols for new bindings created as a result of transforming a
 | ||||||
| /// texture_external instance.
 | /// texture_external instance.
 | ||||||
| struct NewBindingSymbols { | struct NewBindingSymbols { | ||||||
| @ -40,7 +51,7 @@ struct NewBindingSymbols { | |||||||
| }; | }; | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// State holds the current transform state
 | /// PIMPL state for the transform
 | ||||||
| struct MultiplanarExternalTexture::State { | struct MultiplanarExternalTexture::State { | ||||||
|     /// The clone context.
 |     /// The clone context.
 | ||||||
|     CloneContext& ctx; |     CloneContext& ctx; | ||||||
| @ -537,30 +548,26 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default; | |||||||
| MultiplanarExternalTexture::MultiplanarExternalTexture() = default; | MultiplanarExternalTexture::MultiplanarExternalTexture() = default; | ||||||
| MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; | MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; | ||||||
| 
 | 
 | ||||||
| bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |  | ||||||
|         if (auto* ty = node->As<ast::Type>()) { |  | ||||||
|             if (program->Sem().Get<sem::ExternalTexture>(ty)) { |  | ||||||
|                 return true; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Within this transform, an instance of a texture_external binding is unpacked into two
 | // Within this transform, an instance of a texture_external binding is unpacked into two
 | ||||||
| // texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
 | // texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
 | ||||||
| // buffer binding representing a struct of parameters. Calls to texture builtins that contain a
 | // buffer binding representing a struct of parameters. Calls to texture builtins that contain a
 | ||||||
| // texture_external parameter will be transformed into a newly generated version of the function,
 | // texture_external parameter will be transformed into a newly generated version of the function,
 | ||||||
| // which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
 | // which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
 | ||||||
| void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src, | ||||||
|  |                                                          const DataMap& inputs, | ||||||
|  |                                                          DataMap&) const { | ||||||
|     auto* new_binding_points = inputs.Get<NewBindingPoints>(); |     auto* new_binding_points = inputs.Get<NewBindingPoints>(); | ||||||
| 
 | 
 | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|     if (!new_binding_points) { |     if (!new_binding_points) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " + | ||||||
|             diag::System::Transform, |                                                                std::string(TypeInfo().name)); | ||||||
|             "missing new binding point data for " + std::string(TypeInfo().name)); |         return Program(std::move(b)); | ||||||
|         return; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     State state(ctx, new_binding_points); |     State state(ctx, new_binding_points); | ||||||
| @ -568,6 +575,7 @@ void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, D | |||||||
|     state.Process(); |     state.Process(); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -80,21 +80,13 @@ class MultiplanarExternalTexture final : public Castable<MultiplanarExternalText | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~MultiplanarExternalTexture() override; |     ~MultiplanarExternalTexture() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   protected: |   private: | ||||||
|     struct State; |     struct State; | ||||||
| 
 |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -23,7 +23,11 @@ using MultiplanarExternalTextureTest = TransformTest; | |||||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) { | TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) { | ||||||
|     auto* src = R"()"; |     auto* src = R"()"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src)); |     DataMap data; | ||||||
|  |     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||||
|  |         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||||
|  | 
 | ||||||
|  |     EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { | TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { | ||||||
| @ -31,14 +35,22 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { | |||||||
| type ET = texture_external; | type ET = texture_external; | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); |     DataMap data; | ||||||
|  |     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||||
|  |         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||||
|  | 
 | ||||||
|  |     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||||
| } | } | ||||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) { | TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) { | ||||||
|     auto* src = R"( |     auto* src = R"( | ||||||
| @group(0) @binding(0) var ext_tex : texture_external; | @group(0) @binding(0) var ext_tex : texture_external; | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); |     DataMap data; | ||||||
|  |     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||||
|  |         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||||
|  | 
 | ||||||
|  |     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { | TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { | ||||||
| @ -46,7 +58,11 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { | |||||||
| fn f(ext_tex : texture_external) {} | fn f(ext_tex : texture_external) {} | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); |     DataMap data; | ||||||
|  |     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||||
|  |         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||||
|  | 
 | ||||||
|  |     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Running the transform without passing in data for the new bindings should result in an error.
 | // Running the transform without passing in data for the new bindings should result in an error.
 | ||||||
|  | |||||||
| @ -29,6 +29,18 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| namespace { | namespace { | ||||||
|  | 
 | ||||||
|  | bool ShouldRun(const Program* program) { | ||||||
|  |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|  |         if (auto* attr = node->As<ast::BuiltinAttribute>()) { | ||||||
|  |             if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { | ||||||
|  |                 return true; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// Accessor describes the identifiers used in a member accessor that is being
 | /// Accessor describes the identifiers used in a member accessor that is being
 | ||||||
| /// used to retrieve the num_workgroups builtin from a parameter.
 | /// used to retrieve the num_workgroups builtin from a parameter.
 | ||||||
| struct Accessor { | struct Accessor { | ||||||
| @ -44,41 +56,40 @@ struct Accessor { | |||||||
|         size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); } |         size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); } | ||||||
|     }; |     }; | ||||||
| }; | }; | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; | NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; | ||||||
| NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; | NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; | ||||||
| 
 | 
 | ||||||
| bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src, | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |                                                        const DataMap& inputs, | ||||||
|         if (auto* attr = node->As<ast::BuiltinAttribute>()) { |                                                        DataMap&) const { | ||||||
|             if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { |     ProgramBuilder b; | ||||||
|                 return true; |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { |  | ||||||
|     auto* cfg = inputs.Get<Config>(); |     auto* cfg = inputs.Get<Config>(); | ||||||
|     if (cfg == nullptr) { |     if (cfg == nullptr) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||||
|         return; |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     const char* kNumWorkgroupsMemberName = "num_workgroups"; |     const char* kNumWorkgroupsMemberName = "num_workgroups"; | ||||||
| 
 | 
 | ||||||
|     // Find all entry point parameters that declare the num_workgroups builtin.
 |     // Find all entry point parameters that declare the num_workgroups builtin.
 | ||||||
|     std::unordered_set<Accessor, Accessor::Hasher> to_replace; |     std::unordered_set<Accessor, Accessor::Hasher> to_replace; | ||||||
|     for (auto* func : ctx.src->AST().Functions()) { |     for (auto* func : src->AST().Functions()) { | ||||||
|         // num_workgroups is only valid for compute stages.
 |         // num_workgroups is only valid for compute stages.
 | ||||||
|         if (func->PipelineStage() != ast::PipelineStage::kCompute) { |         if (func->PipelineStage() != ast::PipelineStage::kCompute) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         for (auto* param : ctx.src->Sem().Get(func)->Parameters()) { |         for (auto* param : src->Sem().Get(func)->Parameters()) { | ||||||
|             // Because the CanonicalizeEntryPointIO transform has been run, builtins
 |             // Because the CanonicalizeEntryPointIO transform has been run, builtins
 | ||||||
|             // will only appear as struct members.
 |             // will only appear as struct members.
 | ||||||
|             auto* str = param->Type()->As<sem::Struct>(); |             auto* str = param->Type()->As<sem::Struct>(); | ||||||
| @ -108,7 +119,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|                 // If this is the only member, remove the struct and parameter too.
 |                 // If this is the only member, remove the struct and parameter too.
 | ||||||
|                 if (str->Members().size() == 1) { |                 if (str->Members().size() == 1) { | ||||||
|                     ctx.Remove(func->params, param->Declaration()); |                     ctx.Remove(func->params, param->Declaration()); | ||||||
|                     ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration()); |                     ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration()); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @ -119,11 +130,10 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|     const ast::Variable* num_workgroups_ubo = nullptr; |     const ast::Variable* num_workgroups_ubo = nullptr; | ||||||
|     auto get_ubo = [&]() { |     auto get_ubo = [&]() { | ||||||
|         if (!num_workgroups_ubo) { |         if (!num_workgroups_ubo) { | ||||||
|             auto* num_workgroups_struct = ctx.dst->Structure( |             auto* num_workgroups_struct = | ||||||
|                 ctx.dst->Sym(), |                 b.Structure(b.Sym(), utils::Vector{ | ||||||
|                 utils::Vector{ |                                          b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())), | ||||||
|                     ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())), |                                      }); | ||||||
|                 }); |  | ||||||
| 
 | 
 | ||||||
|             uint32_t group, binding; |             uint32_t group, binding; | ||||||
|             if (cfg->ubo_binding.has_value()) { |             if (cfg->ubo_binding.has_value()) { | ||||||
| @ -135,9 +145,9 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|                 // plus 1, or group 0 if no resource bound.
 |                 // plus 1, or group 0 if no resource bound.
 | ||||||
|                 group = 0; |                 group = 0; | ||||||
| 
 | 
 | ||||||
|                 for (auto* global : ctx.src->AST().GlobalVariables()) { |                 for (auto* global : src->AST().GlobalVariables()) { | ||||||
|                     if (global->HasBindingPoint()) { |                     if (global->HasBindingPoint()) { | ||||||
|                         auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global); |                         auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global); | ||||||
|                         auto binding_point = global_sem->BindingPoint(); |                         auto binding_point = global_sem->BindingPoint(); | ||||||
|                         if (binding_point.group >= group) { |                         if (binding_point.group >= group) { | ||||||
|                             group = binding_point.group + 1; |                             group = binding_point.group + 1; | ||||||
| @ -148,16 +158,16 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|                 binding = 0; |                 binding = 0; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             num_workgroups_ubo = ctx.dst->GlobalVar( |             num_workgroups_ubo = | ||||||
|                 ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, |                 b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, | ||||||
|                 ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding))); |                             b.Group(AInt(group)), b.Binding(AInt(binding))); | ||||||
|         } |         } | ||||||
|         return num_workgroups_ubo; |         return num_workgroups_ubo; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     // Now replace all the places where the builtins are accessed with the value
 |     // Now replace all the places where the builtins are accessed with the value
 | ||||||
|     // loaded from the uniform buffer.
 |     // loaded from the uniform buffer.
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         auto* accessor = node->As<ast::MemberAccessorExpression>(); |         auto* accessor = node->As<ast::MemberAccessorExpression>(); | ||||||
|         if (!accessor) { |         if (!accessor) { | ||||||
|             continue; |             continue; | ||||||
| @ -168,12 +178,12 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if (to_replace.count({ident->symbol, accessor->member->symbol})) { |         if (to_replace.count({ident->symbol, accessor->member->symbol})) { | ||||||
|             ctx.Replace(accessor, |             ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); | ||||||
|                         ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp) | NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp) | ||||||
|  | |||||||
| @ -72,19 +72,10 @@ class NumWorkgroupsFromUniform final : public Castable<NumWorkgroupsFromUniform, | |||||||
|         std::optional<sem::BindingPoint> ubo_binding; |         std::optional<sem::BindingPoint> ubo_binding; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -28,7 +28,9 @@ using NumWorkgroupsFromUniformTest = TransformTest; | |||||||
| TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) { | TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) { | ||||||
|     auto* src = R"()"; |     auto* src = R"()"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src)); |     DataMap data; | ||||||
|  |     data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u}); | ||||||
|  |     EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) { | TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) { | ||||||
| @ -38,7 +40,9 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) { | |||||||
| } | } | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src)); |     DataMap data; | ||||||
|  |     data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u}); | ||||||
|  |     EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { | TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { | ||||||
| @ -55,7 +59,6 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) { | |||||||
|     DataMap data; |     DataMap data; | ||||||
|     data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl); |     data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl); | ||||||
|     auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data); |     auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data); | ||||||
| 
 |  | ||||||
|     EXPECT_EQ(expect, str(got)); |     EXPECT_EQ(expect, str(got)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -33,14 +33,15 @@ using namespace tint::number_suffixes;  // NOLINT | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the PackedVec3 transform
 | /// PIMPL state for the transform
 | ||||||
| struct PackedVec3::State { | struct PackedVec3::State { | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param c the CloneContext
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& c) : ctx(c) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|         // Packed vec3<T> struct members
 |         // Packed vec3<T> struct members
 | ||||||
|         utils::Hashset<const sem::StructMember*, 8> members; |         utils::Hashset<const sem::StructMember*, 8> members; | ||||||
| 
 | 
 | ||||||
| @ -72,6 +73,10 @@ struct PackedVec3::State { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         if (members.IsEmpty()) { | ||||||
|  |             return SkipTransform; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         // Walk the nodes, starting with the most deeply nested, finding all the AST expressions
 |         // Walk the nodes, starting with the most deeply nested, finding all the AST expressions
 | ||||||
|         // that load a whole packed vector (not a scalar / swizzle of the vector).
 |         // that load a whole packed vector (not a scalar / swizzle of the vector).
 | ||||||
|         utils::Hashset<const sem::Expression*, 16> refs; |         utils::Hashset<const sem::Expression*, 16> refs; | ||||||
| @ -137,36 +142,20 @@ struct PackedVec3::State { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|     } |         return Program(std::move(b)); | ||||||
| 
 |  | ||||||
|     /// @returns true if this transform should be run for the given program
 |  | ||||||
|     /// @param program the program to inspect
 |  | ||||||
|     static bool ShouldRun(const Program* program) { |  | ||||||
|         for (auto* decl : program->AST().GlobalDeclarations()) { |  | ||||||
|             if (auto* str = program->Sem().Get<sem::Struct>(decl)) { |  | ||||||
|                 if (str->IsHostShareable()) { |  | ||||||
|                     for (auto* member : str->Members()) { |  | ||||||
|                         if (auto* vec = member->Type()->As<sem::Vector>()) { |  | ||||||
|                             if (vec->Width() == 3) { |  | ||||||
|                                 return true; |  | ||||||
|                             } |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return false; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|     /// Alias to the semantic info in ctx.src
 |     /// Alias to the semantic info in ctx.src
 | ||||||
|     const sem::Info& sem = ctx.src->Sem(); |     const sem::Info& sem = ctx.src->Sem(); | ||||||
|     /// Alias to the symbols in ctx.src
 |     /// Alias to the symbols in ctx.src
 | ||||||
|     const SymbolTable& sym = ctx.src->Symbols(); |     const SymbolTable& sym = ctx.src->Symbols(); | ||||||
|     /// Alias to the ctx.dst program builder
 |  | ||||||
|     ProgramBuilder& b = *ctx.dst; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {} | PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {} | ||||||
| @ -183,12 +172,8 @@ std::string PackedVec3::Attribute::InternalName() const { | |||||||
| PackedVec3::PackedVec3() = default; | PackedVec3::PackedVec3() = default; | ||||||
| PackedVec3::~PackedVec3() = default; | PackedVec3::~PackedVec3() = default; | ||||||
| 
 | 
 | ||||||
| bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     return State::ShouldRun(program); |     return State{src}.Run(); | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     State(ctx).Run(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -56,21 +56,13 @@ class PackedVec3 final : public Castable<PackedVec3, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~PackedVec3() override; |     ~PackedVec3() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     struct State; |     struct State; | ||||||
| 
 |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -50,8 +50,10 @@ PadStructs::PadStructs() = default; | |||||||
| 
 | 
 | ||||||
| PadStructs::~PadStructs() = default; | PadStructs::~PadStructs() = default; | ||||||
| 
 | 
 | ||||||
| void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     auto& sem = ctx.src->Sem(); |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|     std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs; |     std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs; | ||||||
|     utils::Hashset<const ast::StructMember*, 8> padding_members; |     utils::Hashset<const ast::StructMember*, 8> padding_members; | ||||||
| @ -65,7 +67,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|         bool has_runtime_sized_array = false; |         bool has_runtime_sized_array = false; | ||||||
|         utils::Vector<const ast::StructMember*, 8> new_members; |         utils::Vector<const ast::StructMember*, 8> new_members; | ||||||
|         for (auto* mem : str->Members()) { |         for (auto* mem : str->Members()) { | ||||||
|             auto name = ctx.src->Symbols().NameFor(mem->Name()); |             auto name = src->Symbols().NameFor(mem->Name()); | ||||||
| 
 | 
 | ||||||
|             if (offset < mem->Offset()) { |             if (offset < mem->Offset()) { | ||||||
|                 CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset); |                 CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset); | ||||||
| @ -75,7 +77,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|             auto* ty = mem->Type(); |             auto* ty = mem->Type(); | ||||||
|             const ast::Type* type = CreateASTTypeFor(ctx, ty); |             const ast::Type* type = CreateASTTypeFor(ctx, ty); | ||||||
| 
 | 
 | ||||||
|             new_members.Push(ctx.dst->Member(name, type)); |             new_members.Push(b.Member(name, type)); | ||||||
| 
 | 
 | ||||||
|             uint32_t size = ty->Size(); |             uint32_t size = ty->Size(); | ||||||
|             if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) { |             if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) { | ||||||
| @ -97,8 +99,8 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|         if (offset < struct_size && !has_runtime_sized_array) { |         if (offset < struct_size && !has_runtime_sized_array) { | ||||||
|             CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset); |             CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset); | ||||||
|         } |         } | ||||||
|         auto* new_struct = ctx.dst->create<ast::Struct>(ctx.Clone(ast_str->name), |         auto* new_struct = | ||||||
|                                                         std::move(new_members), utils::Empty); |             b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty); | ||||||
|         replaced_structs[ast_str] = new_struct; |         replaced_structs[ast_str] = new_struct; | ||||||
|         return new_struct; |         return new_struct; | ||||||
|     }); |     }); | ||||||
| @ -131,16 +133,17 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|         auto* arg = ast_call->args.begin(); |         auto* arg = ast_call->args.begin(); | ||||||
|         for (auto* member : new_struct->members) { |         for (auto* member : new_struct->members) { | ||||||
|             if (padding_members.Contains(member)) { |             if (padding_members.Contains(member)) { | ||||||
|                 new_args.Push(ctx.dst->Expr(0_u)); |                 new_args.Push(b.Expr(0_u)); | ||||||
|             } else { |             } else { | ||||||
|                 new_args.Push(ctx.Clone(*arg)); |                 new_args.Push(ctx.Clone(*arg)); | ||||||
|                 arg++; |                 arg++; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args); |         return b.Construct(CreateASTTypeFor(ctx, str), new_args); | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -30,14 +30,10 @@ class PadStructs final : public Castable<PadStructs, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~PadStructs() override; |     ~PadStructs() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -13,6 +13,9 @@ | |||||||
| // limitations under the License.
 | // limitations under the License.
 | ||||||
| 
 | 
 | ||||||
| #include "src/tint/transform/promote_initializers_to_let.h" | #include "src/tint/transform/promote_initializers_to_let.h" | ||||||
|  | 
 | ||||||
|  | #include <utility> | ||||||
|  | 
 | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| #include "src/tint/sem/call.h" | #include "src/tint/sem/call.h" | ||||||
| #include "src/tint/sem/statement.h" | #include "src/tint/sem/statement.h" | ||||||
| @ -27,9 +30,16 @@ PromoteInitializersToLet::PromoteInitializersToLet() = default; | |||||||
| 
 | 
 | ||||||
| PromoteInitializersToLet::~PromoteInitializersToLet() = default; | PromoteInitializersToLet::~PromoteInitializersToLet() = default; | ||||||
| 
 | 
 | ||||||
| void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src, | ||||||
|  |                                                        const DataMap&, | ||||||
|  |                                                        DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     HoistToDeclBefore hoist_to_decl_before(ctx); |     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||||
| 
 | 
 | ||||||
|  |     bool any_promoted = false; | ||||||
|  | 
 | ||||||
|     // Hoists array and structure initializers to a constant variable, declared
 |     // Hoists array and structure initializers to a constant variable, declared
 | ||||||
|     // just before the statement of usage.
 |     // just before the statement of usage.
 | ||||||
|     auto promote = [&](const sem::Expression* expr) { |     auto promote = [&](const sem::Expression* expr) { | ||||||
| @ -59,14 +69,15 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | |||||||
|             return true; |             return true; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         any_promoted = true; | ||||||
|         return hoist_to_decl_before.Add(expr, expr->Declaration(), true); |         return hoist_to_decl_before.Add(expr, expr->Declaration(), true); | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         bool ok = Switch( |         bool ok = Switch( | ||||||
|             node,  //
 |             node,  //
 | ||||||
|             [&](const ast::CallExpression* expr) { |             [&](const ast::CallExpression* expr) { | ||||||
|                 if (auto* sem = ctx.src->Sem().Get(expr)) { |                 if (auto* sem = src->Sem().Get(expr)) { | ||||||
|                     auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>(); |                     auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|                     if (ctor->Target()->Is<sem::TypeInitializer>()) { |                     if (ctor->Target()->Is<sem::TypeInitializer>()) { | ||||||
|                         return promote(sem); |                         return promote(sem); | ||||||
| @ -75,7 +86,7 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | |||||||
|                 return true; |                 return true; | ||||||
|             }, |             }, | ||||||
|             [&](const ast::IdentifierExpression* expr) { |             [&](const ast::IdentifierExpression* expr) { | ||||||
|                 if (auto* sem = ctx.src->Sem().Get(expr)) { |                 if (auto* sem = src->Sem().Get(expr)) { | ||||||
|                     if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) { |                     if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) { | ||||||
|                         // Identifier resolves to a variable
 |                         // Identifier resolves to a variable
 | ||||||
|                         if (auto* stmt = user->Stmt()) { |                         if (auto* stmt = user->Stmt()) { | ||||||
| @ -96,13 +107,17 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | |||||||
|                 return true; |                 return true; | ||||||
|             }, |             }, | ||||||
|             [&](Default) { return true; }); |             [&](Default) { return true; }); | ||||||
| 
 |  | ||||||
|         if (!ok) { |         if (!ok) { | ||||||
|             return; |             return Program(std::move(b)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if (!any_promoted) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -33,14 +33,10 @@ class PromoteInitializersToLet final : public Castable<PromoteInitializersToLet, | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~PromoteInitializersToLet() override; |     ~PromoteInitializersToLet() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -53,34 +53,36 @@ class StateBase { | |||||||
| // to else {if}s so that the next transform, DecomposeSideEffects, can insert
 | // to else {if}s so that the next transform, DecomposeSideEffects, can insert
 | ||||||
| // hoisted expressions above their current location.
 | // hoisted expressions above their current location.
 | ||||||
| struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> { | struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> { | ||||||
|     class State; |     ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class SimplifySideEffectStatements::State : public StateBase { | Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src, | ||||||
|     HoistToDeclBefore hoist_to_decl_before; |                                                            const DataMap&, | ||||||
|  |                                                            DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
| 
 | 
 | ||||||
|   public: |     bool made_changes = false; | ||||||
|     explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {} |  | ||||||
| 
 | 
 | ||||||
|     void Run() { |     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { |     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||||
|             if (auto* expr = node->As<ast::Expression>()) { |         if (auto* expr = node->As<ast::Expression>()) { | ||||||
|                 auto* sem_expr = sem.Get(expr); |             auto* sem_expr = src->Sem().Get(expr); | ||||||
|                 if (!sem_expr || !sem_expr->HasSideEffects()) { |             if (!sem_expr || !sem_expr->HasSideEffects()) { | ||||||
|                     continue; |                 continue; | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 hoist_to_decl_before.Prepare(sem_expr); |  | ||||||
|             } |             } | ||||||
|         } |  | ||||||
|         ctx.Clone(); |  | ||||||
|     } |  | ||||||
| }; |  | ||||||
| 
 | 
 | ||||||
| void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |             hoist_to_decl_before.Prepare(sem_expr); | ||||||
|     State state(ctx); |             made_changes = true; | ||||||
|     state.Run(); |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (!made_changes) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Decomposes side-effecting expressions to ensure order of evaluation. This
 | // Decomposes side-effecting expressions to ensure order of evaluation. This
 | ||||||
| @ -89,7 +91,7 @@ void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMa | |||||||
| struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> { | struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> { | ||||||
|     class CollectHoistsState; |     class CollectHoistsState; | ||||||
|     class DecomposeState; |     class DecomposeState; | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; |     ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // CollectHoistsState traverses the AST top-down, identifying which expressions
 | // CollectHoistsState traverses the AST top-down, identifying which expressions
 | ||||||
| @ -667,12 +669,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase { | |||||||
|             } |             } | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         }); |         }); | ||||||
| 
 |  | ||||||
|         ctx.Clone(); |  | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src, | ||||||
|  |                                                    const DataMap&, | ||||||
|  |                                                    DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     // First collect side-effecting expressions to hoist
 |     // First collect side-effecting expressions to hoist
 | ||||||
|     CollectHoistsState collect_hoists_state{ctx}; |     CollectHoistsState collect_hoists_state{ctx}; | ||||||
|     auto to_hoist = collect_hoists_state.Run(); |     auto to_hoist = collect_hoists_state.Run(); | ||||||
| @ -680,6 +685,9 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
|     // Now decompose these expressions
 |     // Now decompose these expressions
 | ||||||
|     DecomposeState decompose_state{ctx, std::move(to_hoist)}; |     DecomposeState decompose_state{ctx, std::move(to_hoist)}; | ||||||
|     decompose_state.Run(); |     decompose_state.Run(); | ||||||
|  | 
 | ||||||
|  |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| @ -687,13 +695,13 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | |||||||
| PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default; | PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default; | ||||||
| PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default; | PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default; | ||||||
| 
 | 
 | ||||||
| Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const { | Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src, | ||||||
|  |                                                        const DataMap& inputs, | ||||||
|  |                                                        DataMap& outputs) const { | ||||||
|     transform::Manager manager; |     transform::Manager manager; | ||||||
|     manager.Add<SimplifySideEffectStatements>(); |     manager.Add<SimplifySideEffectStatements>(); | ||||||
|     manager.Add<DecomposeSideEffects>(); |     manager.Add<DecomposeSideEffects>(); | ||||||
| 
 |     return manager.Apply(src, inputs, outputs); | ||||||
|     auto output = manager.Run(program, data); |  | ||||||
|     return output; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,12 +31,10 @@ class PromoteSideEffectsToDecl final : public Castable<PromoteSideEffectsToDecl, | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~PromoteSideEffectsToDecl() override; |     ~PromoteSideEffectsToDecl() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform on `program`, returning the transformation result.
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @param program the source program to transform
 |                       const DataMap& inputs, | ||||||
|     /// @param data optional extra transform-specific data
 |                       DataMap& outputs) const override; | ||||||
|     /// @returns the transformation result
 |  | ||||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -32,53 +32,19 @@ | |||||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch); | TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch); | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| namespace { |  | ||||||
| 
 | 
 | ||||||
| class State { | /// PIMPL state for the transform
 | ||||||
|   private: | struct RemoveContinueInSwitch::State { | ||||||
|     CloneContext& ctx; |  | ||||||
|     ProgramBuilder& b; |  | ||||||
|     const sem::Info& sem; |  | ||||||
| 
 |  | ||||||
|     // Map of switch statement to 'tint_continue' variable.
 |  | ||||||
|     std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name; |  | ||||||
| 
 |  | ||||||
|     // If `cont` is within a switch statement within a loop, returns a pointer to
 |  | ||||||
|     // that switch statement.
 |  | ||||||
|     static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, |  | ||||||
|                                                              const ast::ContinueStatement* cont) { |  | ||||||
|         // Find whether first parent is a switch or a loop
 |  | ||||||
|         auto* sem_stmt = sem.Get(cont); |  | ||||||
|         auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement, |  | ||||||
|                                                      sem::ForLoopStatement, sem::WhileStatement>(); |  | ||||||
|         if (!sem_parent) { |  | ||||||
|             return nullptr; |  | ||||||
|         } |  | ||||||
|         return sem_parent->Declaration()->As<ast::SwitchStatement>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|   public: |  | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param ctx_in the context
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 |  | ||||||
|     /// Returns true if this transform should be run for the given program
 |  | ||||||
|     static bool ShouldRun(const Program* program) { |  | ||||||
|         for (auto* node : program->ASTNodes().Objects()) { |  | ||||||
|             auto* stmt = node->As<ast::ContinueStatement>(); |  | ||||||
|             if (!stmt) { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|             if (GetParentSwitchInLoop(program->Sem(), stmt)) { |  | ||||||
|                 return true; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { |     ApplyResult Run() { | ||||||
|  |         bool made_changes = false; | ||||||
|  | 
 | ||||||
|  |         for (auto* node : src->ASTNodes().Objects()) { | ||||||
|             auto* cont = node->As<ast::ContinueStatement>(); |             auto* cont = node->As<ast::ContinueStatement>(); | ||||||
|             if (!cont) { |             if (!cont) { | ||||||
|                 continue; |                 continue; | ||||||
| @ -90,6 +56,8 @@ class State { | |||||||
|                 continue; |                 continue; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  |             made_changes = true; | ||||||
|  | 
 | ||||||
|             auto cont_var_name = |             auto cont_var_name = | ||||||
|                 tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() { |                 tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() { | ||||||
|                     // Create and insert 'var tint_continue : bool = false;' before the
 |                     // Create and insert 'var tint_continue : bool = false;' before the
 | ||||||
| @ -116,22 +84,50 @@ class State { | |||||||
|             ctx.Replace(cont, new_stmt); |             ctx.Replace(cont, new_stmt); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         if (!made_changes) { | ||||||
|  |             return SkipTransform; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     /// The clone context
 | ||||||
|  |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |     /// Alias to src->sem
 | ||||||
|  |     const sem::Info& sem = src->Sem(); | ||||||
|  | 
 | ||||||
|  |     // Map of switch statement to 'tint_continue' variable.
 | ||||||
|  |     std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name; | ||||||
|  | 
 | ||||||
|  |     // If `cont` is within a switch statement within a loop, returns a pointer to
 | ||||||
|  |     // that switch statement.
 | ||||||
|  |     static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, | ||||||
|  |                                                              const ast::ContinueStatement* cont) { | ||||||
|  |         // Find whether first parent is a switch or a loop
 | ||||||
|  |         auto* sem_stmt = sem.Get(cont); | ||||||
|  |         auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement, | ||||||
|  |                                                      sem::ForLoopStatement, sem::WhileStatement>(); | ||||||
|  |         if (!sem_parent) { | ||||||
|  |             return nullptr; | ||||||
|  |         } | ||||||
|  |         return sem_parent->Declaration()->As<ast::SwitchStatement>(); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace
 |  | ||||||
| 
 |  | ||||||
| RemoveContinueInSwitch::RemoveContinueInSwitch() = default; | RemoveContinueInSwitch::RemoveContinueInSwitch() = default; | ||||||
| RemoveContinueInSwitch::~RemoveContinueInSwitch() = default; | RemoveContinueInSwitch::~RemoveContinueInSwitch() = default; | ||||||
| 
 | 
 | ||||||
| bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const { | Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src, | ||||||
|     return State::ShouldRun(program); |                                                      const DataMap&, | ||||||
| } |                                                      DataMap&) const { | ||||||
| 
 |     State state(src); | ||||||
| void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |     return state.Run(); | ||||||
|     State state(ctx); |  | ||||||
|     state.Run(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,19 +31,13 @@ class RemoveContinueInSwitch final : public Castable<RemoveContinueInSwitch, Tra | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~RemoveContinueInSwitch() override; |     ~RemoveContinueInSwitch() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param program the program to inspect
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @param data optional extra transform-specific input data
 |                       const DataMap& inputs, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       DataMap& outputs) const override; | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |  | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |   private: | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |     struct State; | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -41,34 +41,25 @@ RemovePhonies::RemovePhonies() = default; | |||||||
| 
 | 
 | ||||||
| RemovePhonies::~RemovePhonies() = default; | RemovePhonies::~RemovePhonies() = default; | ||||||
| 
 | 
 | ||||||
| bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     ProgramBuilder b; | ||||||
|         if (node->Is<ast::PhonyExpression>()) { |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
|         if (auto* stmt = node->As<ast::CallStatement>()) { |  | ||||||
|             if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) { |  | ||||||
|                 return true; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |     auto& sem = src->Sem(); | ||||||
|     auto& sem = ctx.src->Sem(); |  | ||||||
| 
 | 
 | ||||||
|     std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks; |     utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks; | ||||||
| 
 | 
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     bool made_changes = false; | ||||||
|  |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         Switch( |         Switch( | ||||||
|             node, |             node, | ||||||
|             [&](const ast::AssignmentStatement* stmt) { |             [&](const ast::AssignmentStatement* stmt) { | ||||||
|                 if (stmt->lhs->Is<ast::PhonyExpression>()) { |                 if (stmt->lhs->Is<ast::PhonyExpression>()) { | ||||||
|  |                     made_changes = true; | ||||||
|  | 
 | ||||||
|                     std::vector<const ast::Expression*> side_effects; |                     std::vector<const ast::Expression*> side_effects; | ||||||
|                     if (!ast::TraverseExpressions( |                     if (!ast::TraverseExpressions( | ||||||
|                             stmt->rhs, ctx.dst->Diagnostics(), |                             stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) { | ||||||
|                             [&](const ast::CallExpression* expr) { |  | ||||||
|                                 // ast::CallExpression may map to a function or builtin call
 |                                 // ast::CallExpression may map to a function or builtin call
 | ||||||
|                                 // (both may have side-effects), or a type initializer or
 |                                 // (both may have side-effects), or a type initializer or
 | ||||||
|                                 // type conversion (both do not have side effects).
 |                                 // type conversion (both do not have side effects).
 | ||||||
| @ -100,8 +91,7 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|                         if (auto* call = side_effects[0]->As<ast::CallExpression>()) { |                         if (auto* call = side_effects[0]->As<ast::CallExpression>()) { | ||||||
|                             // Phony assignment with single call side effect.
 |                             // Phony assignment with single call side effect.
 | ||||||
|                             // Replace phony assignment with call.
 |                             // Replace phony assignment with call.
 | ||||||
|                             ctx.Replace(stmt, |                             ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); }); | ||||||
|                                         [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); }); |  | ||||||
|                             return; |                             return; | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
| @ -114,22 +104,21 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|                         for (auto* arg : side_effects) { |                         for (auto* arg : side_effects) { | ||||||
|                             sig.push_back(sem.Get(arg)->Type()->UnwrapRef()); |                             sig.push_back(sem.Get(arg)->Type()->UnwrapRef()); | ||||||
|                         } |                         } | ||||||
|                         auto sink = utils::GetOrCreate(sinks, sig, [&] { |                         auto sink = sinks.GetOrCreate(sig, [&] { | ||||||
|                             auto name = ctx.dst->Symbols().New("phony_sink"); |                             auto name = b.Symbols().New("phony_sink"); | ||||||
|                             utils::Vector<const ast::Parameter*, 8> params; |                             utils::Vector<const ast::Parameter*, 8> params; | ||||||
|                             for (auto* ty : sig) { |                             for (auto* ty : sig) { | ||||||
|                                 auto* ast_ty = CreateASTTypeFor(ctx, ty); |                                 auto* ast_ty = CreateASTTypeFor(ctx, ty); | ||||||
|                                 params.Push( |                                 params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty)); | ||||||
|                                     ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty)); |  | ||||||
|                             } |                             } | ||||||
|                             ctx.dst->Func(name, params, ctx.dst->ty.void_(), {}); |                             b.Func(name, params, b.ty.void_(), {}); | ||||||
|                             return name; |                             return name; | ||||||
|                         }); |                         }); | ||||||
|                         utils::Vector<const ast::Expression*, 8> args; |                         utils::Vector<const ast::Expression*, 8> args; | ||||||
|                         for (auto* arg : side_effects) { |                         for (auto* arg : side_effects) { | ||||||
|                             args.Push(ctx.Clone(arg)); |                             args.Push(ctx.Clone(arg)); | ||||||
|                         } |                         } | ||||||
|                         return ctx.dst->CallStmt(ctx.dst->Call(sink, args)); |                         return b.CallStmt(b.Call(sink, args)); | ||||||
|                     }); |                     }); | ||||||
|                 } |                 } | ||||||
|             }, |             }, | ||||||
| @ -138,12 +127,18 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
|                 // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
 |                 // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
 | ||||||
|                 auto* sem_expr = sem.Get(stmt->expr); |                 auto* sem_expr = sem.Get(stmt->expr); | ||||||
|                 if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) { |                 if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) { | ||||||
|  |                     made_changes = true; | ||||||
|                     ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt); |                     ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt); | ||||||
|                 } |                 } | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if (!made_changes) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -33,19 +33,10 @@ class RemovePhonies final : public Castable<RemovePhonies, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~RemovePhonies() override; |     ~RemovePhonies() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -36,27 +36,28 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default; | |||||||
| 
 | 
 | ||||||
| RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; | RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; | ||||||
| 
 | 
 | ||||||
| bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src, | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |                                                           const DataMap&, | ||||||
|         if (auto* stmt = program->Sem().Get<sem::Statement>(node)) { |                                                           DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     bool made_changes = false; | ||||||
|  |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|  |         if (auto* stmt = src->Sem().Get<sem::Statement>(node)) { | ||||||
|             if (!stmt->IsReachable()) { |             if (!stmt->IsReachable()) { | ||||||
|                 return true; |                 RemoveStatement(ctx, stmt->Declaration()); | ||||||
|  |                 made_changes = true; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |     if (!made_changes) { | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |         return SkipTransform; | ||||||
|         if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) { |  | ||||||
|             if (!stmt->IsReachable()) { |  | ||||||
|                 RemoveStatement(ctx, stmt->Declaration()); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -32,19 +32,10 @@ class RemoveUnreachableStatements final : public Castable<RemoveUnreachableState | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~RemoveUnreachableStatements() override; |     ~RemoveUnreachableStatements() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -1252,39 +1252,31 @@ Renamer::Config::~Config() = default; | |||||||
| Renamer::Renamer() = default; | Renamer::Renamer() = default; | ||||||
| Renamer::~Renamer() = default; | Renamer::~Renamer() = default; | ||||||
| 
 | 
 | ||||||
| Output Renamer::Run(const Program* in, const DataMap& inputs) const { | Transform::ApplyResult Renamer::Apply(const Program* src, | ||||||
|     ProgramBuilder out; |                                       const DataMap& inputs, | ||||||
|     // Disable auto-cloning of symbols, since we want to rename them.
 |                                       DataMap& outputs) const { | ||||||
|     CloneContext ctx(&out, in, false); |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ false}; | ||||||
| 
 | 
 | ||||||
|     // Swizzles, builtin calls and builtin structure members need to keep their
 |     // Swizzles, builtin calls and builtin structure members need to keep their
 | ||||||
|     // symbols preserved.
 |     // symbols preserved.
 | ||||||
|     std::unordered_set<const ast::IdentifierExpression*> preserve; |     utils::Hashset<const ast::IdentifierExpression*, 8> preserve; | ||||||
|     for (auto* node : in->ASTNodes().Objects()) { |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* member = node->As<ast::MemberAccessorExpression>()) { |         if (auto* member = node->As<ast::MemberAccessorExpression>()) { | ||||||
|             auto* sem = in->Sem().Get(member); |             auto* sem = src->Sem().Get(member); | ||||||
|             if (!sem) { |  | ||||||
|                 TINT_ICE(Transform, out.Diagnostics()) |  | ||||||
|                     << "MemberAccessorExpression has no semantic info"; |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|             if (sem->Is<sem::Swizzle>()) { |             if (sem->Is<sem::Swizzle>()) { | ||||||
|                 preserve.emplace(member->member); |                 preserve.Add(member->member); | ||||||
|             } else if (auto* str_expr = in->Sem().Get(member->structure)) { |             } else if (auto* str_expr = src->Sem().Get(member->structure)) { | ||||||
|                 if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) { |                 if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) { | ||||||
|                     if (ty->Declaration() == nullptr) {  // Builtin structure
 |                     if (ty->Declaration() == nullptr) {  // Builtin structure
 | ||||||
|                         preserve.emplace(member->member); |                         preserve.Add(member->member); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } else if (auto* call = node->As<ast::CallExpression>()) { |         } else if (auto* call = node->As<ast::CallExpression>()) { | ||||||
|             auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>(); |             auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|             if (!sem) { |  | ||||||
|                 TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info"; |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|             if (sem->Target()->Is<sem::Builtin>()) { |             if (sem->Target()->Is<sem::Builtin>()) { | ||||||
|                 preserve.emplace(call->target.name); |                 preserve.Add(call->target.name); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -1300,7 +1292,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](Symbol sym_in) { |     ctx.ReplaceAll([&](Symbol sym_in) { | ||||||
|         auto name_in = ctx.src->Symbols().NameFor(sym_in); |         auto name_in = src->Symbols().NameFor(sym_in); | ||||||
|         if (preserve_unicode || text::utf8::IsASCII(name_in)) { |         if (preserve_unicode || text::utf8::IsASCII(name_in)) { | ||||||
|             switch (target) { |             switch (target) { | ||||||
|                 case Target::kAll: |                 case Target::kAll: | ||||||
| @ -1343,17 +1335,20 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { | |||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* { |     ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* { | ||||||
|         if (preserve.count(ident)) { |         if (preserve.Contains(ident)) { | ||||||
|             auto sym_in = ident->symbol; |             auto sym_in = ident->symbol; | ||||||
|             auto str = in->Symbols().NameFor(sym_in); |             auto str = src->Symbols().NameFor(sym_in); | ||||||
|             auto sym_out = out.Symbols().Register(str); |             auto sym_out = b.Symbols().Register(str); | ||||||
|             return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out); |             return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out); | ||||||
|         } |         } | ||||||
|         return nullptr;  // Clone ident. Uses the symbol remapping above.
 |         return nullptr;  // Clone ident. Uses the symbol remapping above.
 | ||||||
|     }); |     }); | ||||||
|     ctx.Clone(); |  | ||||||
| 
 | 
 | ||||||
|     return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings))); |     ctx.Clone();  // Must come before the std::move()
 | ||||||
|  | 
 | ||||||
|  |     outputs.Add<Data>(std::move(remappings)); | ||||||
|  | 
 | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -85,11 +85,10 @@ class Renamer final : public Castable<Renamer, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Renamer() override; |     ~Renamer() override; | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform on `program`, returning the transformation result.
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param program the source program to transform
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @param data optional extra transform-specific input data
 |                       const DataMap& inputs, | ||||||
|     /// @returns the transformation result
 |                       DataMap& outputs) const override; | ||||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -33,36 +33,48 @@ using namespace tint::number_suffixes;  // NOLINT | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// State holds the current transform state
 | /// PIMPL state for the transform
 | ||||||
| struct Robustness::State { | struct Robustness::State { | ||||||
|     /// The clone context
 |     /// Constructor
 | ||||||
|     CloneContext& ctx; |     /// @param program the source program
 | ||||||
|  |     /// @param omitted the omitted address spaces
 | ||||||
|  |     State(const Program* program, std::unordered_set<ast::AddressSpace>&& omitted) | ||||||
|  |         : src(program), omitted_address_spaces(std::move(omitted)) {} | ||||||
| 
 | 
 | ||||||
|     /// Set of address spacees to not apply the transform to
 |     /// Runs the transform
 | ||||||
|     std::unordered_set<ast::AddressSpace> omitted_classes; |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
| 
 |     ApplyResult Run() { | ||||||
|     /// Applies the transformation state to `ctx`.
 |  | ||||||
|     void Transform() { |  | ||||||
|         ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); }); |         ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); }); | ||||||
|         ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); }); |         ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); }); | ||||||
|  | 
 | ||||||
|  |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |   private: | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     /// The clone context
 | ||||||
|  |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|  |     /// Set of address spaces to not apply the transform to
 | ||||||
|  |     std::unordered_set<ast::AddressSpace> omitted_address_spaces; | ||||||
|  | 
 | ||||||
|     /// Apply bounds clamping to array, vector and matrix indexing
 |     /// Apply bounds clamping to array, vector and matrix indexing
 | ||||||
|     /// @param expr the array, vector or matrix index expression
 |     /// @param expr the array, vector or matrix index expression
 | ||||||
|     /// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
 |     /// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
 | ||||||
|     /// changes.
 |     /// changes.
 | ||||||
|     const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) { |     const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) { | ||||||
|         auto* sem = |         auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>(); | ||||||
|             ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>(); |  | ||||||
|         auto* ret_type = sem->Type(); |         auto* ret_type = sem->Type(); | ||||||
| 
 | 
 | ||||||
|         auto* ref = ret_type->As<sem::Reference>(); |         auto* ref = ret_type->As<sem::Reference>(); | ||||||
|         if (ref && omitted_classes.count(ref->AddressSpace()) != 0) { |         if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) { | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         ProgramBuilder& b = *ctx.dst; |  | ||||||
| 
 |  | ||||||
|         // idx return the cloned index expression, as a u32.
 |         // idx return the cloned index expression, as a u32.
 | ||||||
|         auto idx = [&]() -> const ast::Expression* { |         auto idx = [&]() -> const ast::Expression* { | ||||||
|             auto* i = ctx.Clone(expr->index); |             auto* i = ctx.Clone(expr->index); | ||||||
| @ -109,8 +121,8 @@ struct Robustness::State { | |||||||
|                 } else { |                 } else { | ||||||
|                     // Note: Don't be tempted to use the array override variable as an expression
 |                     // Note: Don't be tempted to use the array override variable as an expression
 | ||||||
|                     // here, the name might be shadowed!
 |                     // here, the name might be shadowed!
 | ||||||
|                     ctx.dst->Diagnostics().add_error(diag::System::Transform, |                     b.Diagnostics().add_error(diag::System::Transform, | ||||||
|                                                      sem::Array::kErrExpectedConstantCount); |                                               sem::Array::kErrExpectedConstantCount); | ||||||
|                     return nullptr; |                     return nullptr; | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
| @ -119,7 +131,7 @@ struct Robustness::State { | |||||||
|             [&](Default) { |             [&](Default) { | ||||||
|                 TINT_ICE(Transform, b.Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "unhandled object type in robustness of array index: " |                     << "unhandled object type in robustness of array index: " | ||||||
|                     << ctx.src->FriendlyName(ret_type->UnwrapRef()); |                     << src->FriendlyName(ret_type->UnwrapRef()); | ||||||
|                 return nullptr; |                 return nullptr; | ||||||
|             }); |             }); | ||||||
| 
 | 
 | ||||||
| @ -127,9 +139,9 @@ struct Robustness::State { | |||||||
|             return nullptr;  // Clamping not needed
 |             return nullptr;  // Clamping not needed
 | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         auto src = ctx.Clone(expr->source); |         auto idx_src = ctx.Clone(expr->source); | ||||||
|         auto* obj = ctx.Clone(expr->object); |         auto* idx_obj = ctx.Clone(expr->object); | ||||||
|         return b.IndexAccessor(src, obj, clamped_idx); |         return b.IndexAccessor(idx_src, idx_obj, clamped_idx); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// @param type builtin type
 |     /// @param type builtin type
 | ||||||
| @ -145,15 +157,13 @@ struct Robustness::State { | |||||||
|     /// @return the clamped replacement call expression, or nullptr if `expr`
 |     /// @return the clamped replacement call expression, or nullptr if `expr`
 | ||||||
|     /// should be cloned without changes.
 |     /// should be cloned without changes.
 | ||||||
|     const ast::CallExpression* Transform(const ast::CallExpression* expr) { |     const ast::CallExpression* Transform(const ast::CallExpression* expr) { | ||||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); |         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|         auto* call_target = call->Target(); |         auto* call_target = call->Target(); | ||||||
|         auto* builtin = call_target->As<sem::Builtin>(); |         auto* builtin = call_target->As<sem::Builtin>(); | ||||||
|         if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) { |         if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) { | ||||||
|             return nullptr;  // No transform, just clone.
 |             return nullptr;  // No transform, just clone.
 | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         ProgramBuilder& b = *ctx.dst; |  | ||||||
| 
 |  | ||||||
|         // Indices of the mandatory texture and coords parameters, and the optional
 |         // Indices of the mandatory texture and coords parameters, and the optional
 | ||||||
|         // array and level parameters.
 |         // array and level parameters.
 | ||||||
|         auto& signature = builtin->Signature(); |         auto& signature = builtin->Signature(); | ||||||
| @ -261,7 +271,7 @@ struct Robustness::State { | |||||||
|         // Clamp the level argument, if provided
 |         // Clamp the level argument, if provided
 | ||||||
|         if (level_idx >= 0) { |         if (level_idx >= 0) { | ||||||
|             auto* arg = expr->args[static_cast<size_t>(level_idx)]; |             auto* arg = expr->args[static_cast<size_t>(level_idx)]; | ||||||
|             ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a)); |             ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return nullptr;  // Clone, which will use the argument replacements above.
 |         return nullptr;  // Clone, which will use the argument replacements above.
 | ||||||
| @ -276,28 +286,27 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default; | |||||||
| Robustness::Robustness() = default; | Robustness::Robustness() = default; | ||||||
| Robustness::~Robustness() = default; | Robustness::~Robustness() = default; | ||||||
| 
 | 
 | ||||||
| void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult Robustness::Apply(const Program* src, | ||||||
|  |                                          const DataMap& inputs, | ||||||
|  |                                          DataMap&) const { | ||||||
|     Config cfg; |     Config cfg; | ||||||
|     if (auto* cfg_data = inputs.Get<Config>()) { |     if (auto* cfg_data = inputs.Get<Config>()) { | ||||||
|         cfg = *cfg_data; |         cfg = *cfg_data; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     std::unordered_set<ast::AddressSpace> omitted_classes; |     std::unordered_set<ast::AddressSpace> omitted_address_spaces; | ||||||
|     for (auto sc : cfg.omitted_classes) { |     for (auto sc : cfg.omitted_address_spaces) { | ||||||
|         switch (sc) { |         switch (sc) { | ||||||
|             case AddressSpace::kUniform: |             case AddressSpace::kUniform: | ||||||
|                 omitted_classes.insert(ast::AddressSpace::kUniform); |                 omitted_address_spaces.insert(ast::AddressSpace::kUniform); | ||||||
|                 break; |                 break; | ||||||
|             case AddressSpace::kStorage: |             case AddressSpace::kStorage: | ||||||
|                 omitted_classes.insert(ast::AddressSpace::kStorage); |                 omitted_address_spaces.insert(ast::AddressSpace::kStorage); | ||||||
|                 break; |                 break; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     State state{ctx, std::move(omitted_classes)}; |     return State{src, std::move(omitted_address_spaces)}.Run(); | ||||||
| 
 |  | ||||||
|     state.Transform(); |  | ||||||
|     ctx.Clone(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -54,9 +54,9 @@ class Robustness final : public Castable<Robustness, Transform> { | |||||||
|         /// @returns this Config
 |         /// @returns this Config
 | ||||||
|         Config& operator=(const Config&); |         Config& operator=(const Config&); | ||||||
| 
 | 
 | ||||||
|         /// Address spacees to omit from apply the transform to.
 |         /// Address spaces to omit from apply the transform to.
 | ||||||
|         /// This allows for optimizing on hardware that provide safe accesses.
 |         /// This allows for optimizing on hardware that provide safe accesses.
 | ||||||
|         std::unordered_set<AddressSpace> omitted_classes; |         std::unordered_set<AddressSpace> omitted_address_spaces; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
| @ -64,14 +64,10 @@ class Robustness final : public Castable<Robustness, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Robustness() override; |     ~Robustness() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     struct State; |     struct State; | ||||||
|  | |||||||
| @ -1274,7 +1274,7 @@ fn f() { | |||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     Robustness::Config cfg; |     Robustness::Config cfg; | ||||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); |     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); | ||||||
| 
 | 
 | ||||||
|     DataMap data; |     DataMap data; | ||||||
|     data.Add<Robustness::Config>(cfg); |     data.Add<Robustness::Config>(cfg); | ||||||
| @ -1325,7 +1325,7 @@ fn f() { | |||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     Robustness::Config cfg; |     Robustness::Config cfg; | ||||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); |     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); | ||||||
| 
 | 
 | ||||||
|     DataMap data; |     DataMap data; | ||||||
|     data.Add<Robustness::Config>(cfg); |     data.Add<Robustness::Config>(cfg); | ||||||
| @ -1376,8 +1376,8 @@ fn f() { | |||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|     Robustness::Config cfg; |     Robustness::Config cfg; | ||||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); |     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); | ||||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); |     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); | ||||||
| 
 | 
 | ||||||
|     DataMap data; |     DataMap data; | ||||||
|     data.Add<Robustness::Config>(cfg); |     data.Add<Robustness::Config>(cfg); | ||||||
|  | |||||||
| @ -45,14 +45,18 @@ struct PointerOp { | |||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the SimplifyPointers transform
 | /// PIMPL state for the transform
 | ||||||
| struct SimplifyPointers::State { | struct SimplifyPointers::State { | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
| 
 | 
 | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param context the clone context
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& context) : ctx(context) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 | 
 | ||||||
|     /// Traverses the expression `expr` looking for non-literal array indexing
 |     /// Traverses the expression `expr` looking for non-literal array indexing
 | ||||||
|     /// expressions that would affect the computed address of a pointer
 |     /// expressions that would affect the computed address of a pointer
 | ||||||
| @ -120,10 +124,11 @@ struct SimplifyPointers::State { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Performs the transformation
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|         // A map of saved expressions to their saved variable name
 |         // A map of saved expressions to their saved variable name
 | ||||||
|         std::unordered_map<const ast::Expression*, Symbol> saved_vars; |         utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars; | ||||||
| 
 | 
 | ||||||
|         // Register the ast::Expression transform handler.
 |         // Register the ast::Expression transform handler.
 | ||||||
|         // This performs two different transformations:
 |         // This performs two different transformations:
 | ||||||
| @ -135,9 +140,8 @@ struct SimplifyPointers::State { | |||||||
|         // variable identifier.
 |         // variable identifier.
 | ||||||
|         ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { |         ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { | ||||||
|             // Look to see if we need to swap this Expression with a saved variable.
 |             // Look to see if we need to swap this Expression with a saved variable.
 | ||||||
|             auto it = saved_vars.find(expr); |             if (auto* saved_var = saved_vars.Find(expr)) { | ||||||
|             if (it != saved_vars.end()) { |                 return ctx.dst->Expr(*saved_var); | ||||||
|                 return ctx.dst->Expr(it->second); |  | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // Reduce the expression, folding away chains of address-of / indirections
 |             // Reduce the expression, folding away chains of address-of / indirections
 | ||||||
| @ -174,7 +178,7 @@ struct SimplifyPointers::State { | |||||||
| 
 | 
 | ||||||
|                 // Scan the initializer expression for array index expressions that need
 |                 // Scan the initializer expression for array index expressions that need
 | ||||||
|                 // to be hoist to temporary "saved" variables.
 |                 // to be hoist to temporary "saved" variables.
 | ||||||
|                 std::vector<const ast::VariableDeclStatement*> saved; |                 utils::Vector<const ast::VariableDeclStatement*, 8> saved; | ||||||
|                 CollectSavedArrayIndices( |                 CollectSavedArrayIndices( | ||||||
|                     var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { |                     var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { | ||||||
|                         // We have a sub-expression that needs to be saved.
 |                         // We have a sub-expression that needs to be saved.
 | ||||||
| @ -182,18 +186,18 @@ struct SimplifyPointers::State { | |||||||
|                         auto saved_name = ctx.dst->Symbols().New( |                         auto saved_name = ctx.dst->Symbols().New( | ||||||
|                             ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); |                             ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); | ||||||
|                         auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); |                         auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); | ||||||
|                         saved.emplace_back(decl); |                         saved.Push(decl); | ||||||
|                         // Record the substitution of `idx_expr` to the saved variable
 |                         // Record the substitution of `idx_expr` to the saved variable
 | ||||||
|                         // with the symbol `saved_name`. This will be used by the
 |                         // with the symbol `saved_name`. This will be used by the
 | ||||||
|                         // ReplaceAll() handler above.
 |                         // ReplaceAll() handler above.
 | ||||||
|                         saved_vars.emplace(idx_expr, saved_name); |                         saved_vars.Add(idx_expr, saved_name); | ||||||
|                     }); |                     }); | ||||||
| 
 | 
 | ||||||
|                 // Find the place to insert the saved declarations.
 |                 // Find the place to insert the saved declarations.
 | ||||||
|                 // Special care needs to be made for lets declared as the initializer
 |                 // Special care needs to be made for lets declared as the initializer
 | ||||||
|                 // part of for-loops. In this case the block will hold the for-loop
 |                 // part of for-loops. In this case the block will hold the for-loop
 | ||||||
|                 // statement, not the let.
 |                 // statement, not the let.
 | ||||||
|                 if (!saved.empty()) { |                 if (!saved.IsEmpty()) { | ||||||
|                     auto* stmt = ctx.src->Sem().Get(let); |                     auto* stmt = ctx.src->Sem().Get(let); | ||||||
|                     auto* block = stmt->Block(); |                     auto* block = stmt->Block(); | ||||||
|                     // Find the statement owned by the block (either the let decl or a
 |                     // Find the statement owned by the block (either the let decl or a
 | ||||||
| @ -219,7 +223,9 @@ struct SimplifyPointers::State { | |||||||
|                 RemoveStatement(ctx, let); |                 RemoveStatement(ctx, let); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| @ -227,8 +233,8 @@ SimplifyPointers::SimplifyPointers() = default; | |||||||
| 
 | 
 | ||||||
| SimplifyPointers::~SimplifyPointers() = default; | SimplifyPointers::~SimplifyPointers() = default; | ||||||
| 
 | 
 | ||||||
| void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     State(ctx).Run(); |     return State(src).Run(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -39,16 +39,13 @@ class SimplifyPointers final : public Castable<SimplifyPointers, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~SimplifyPointers() override; |     ~SimplifyPointers() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     struct State; |     ApplyResult Apply(const Program* program, | ||||||
|  |                       const DataMap& inputs, | ||||||
|  |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |   private: | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |     struct State; | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -30,33 +30,37 @@ SingleEntryPoint::SingleEntryPoint() = default; | |||||||
| 
 | 
 | ||||||
| SingleEntryPoint::~SingleEntryPoint() = default; | SingleEntryPoint::~SingleEntryPoint() = default; | ||||||
| 
 | 
 | ||||||
| void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, | ||||||
|  |                                                const DataMap& inputs, | ||||||
|  |                                                DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     auto* cfg = inputs.Get<Config>(); |     auto* cfg = inputs.Get<Config>(); | ||||||
|     if (cfg == nullptr) { |     if (cfg == nullptr) { | ||||||
|         ctx.dst->Diagnostics().add_error( |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||||
| 
 |         return Program(std::move(b)); | ||||||
|         return; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Find the target entry point.
 |     // Find the target entry point.
 | ||||||
|     const ast::Function* entry_point = nullptr; |     const ast::Function* entry_point = nullptr; | ||||||
|     for (auto* f : ctx.src->AST().Functions()) { |     for (auto* f : src->AST().Functions()) { | ||||||
|         if (!f->IsEntryPoint()) { |         if (!f->IsEntryPoint()) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|         if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { |         if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { | ||||||
|             entry_point = f; |             entry_point = f; | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     if (entry_point == nullptr) { |     if (entry_point == nullptr) { | ||||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, |         b.Diagnostics().add_error(diag::System::Transform, | ||||||
|                                          "entry point '" + cfg->entry_point_name + "' not found"); |                                   "entry point '" + cfg->entry_point_name + "' not found"); | ||||||
|         return; |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     auto& sem = ctx.src->Sem(); |     auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|     // Build set of referenced module-scope variables for faster lookups later.
 |     // Build set of referenced module-scope variables for faster lookups later.
 | ||||||
|     std::unordered_set<const ast::Variable*> referenced_vars; |     std::unordered_set<const ast::Variable*> referenced_vars; | ||||||
| @ -66,12 +70,12 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c | |||||||
| 
 | 
 | ||||||
|     // Clone any module-scope variables, types, and functions that are statically referenced by the
 |     // Clone any module-scope variables, types, and functions that are statically referenced by the
 | ||||||
|     // target entry point.
 |     // target entry point.
 | ||||||
|     for (auto* decl : ctx.src->AST().GlobalDeclarations()) { |     for (auto* decl : src->AST().GlobalDeclarations()) { | ||||||
|         Switch( |         Switch( | ||||||
|             decl,  //
 |             decl,  //
 | ||||||
|             [&](const ast::TypeDecl* ty) { |             [&](const ast::TypeDecl* ty) { | ||||||
|                 // TODO(jrprice): Strip unused types.
 |                 // TODO(jrprice): Strip unused types.
 | ||||||
|                 ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); |                 b.AST().AddTypeDecl(ctx.Clone(ty)); | ||||||
|             }, |             }, | ||||||
|             [&](const ast::Override* override) { |             [&](const ast::Override* override) { | ||||||
|                 if (referenced_vars.count(override)) { |                 if (referenced_vars.count(override)) { | ||||||
| @ -80,37 +84,39 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c | |||||||
|                         // so that its allocated ID so that it won't be affected by other
 |                         // so that its allocated ID so that it won't be affected by other
 | ||||||
|                         // stripped away overrides
 |                         // stripped away overrides
 | ||||||
|                         auto* global = sem.Get(override); |                         auto* global = sem.Get(override); | ||||||
|                         const auto* id = ctx.dst->Id(global->OverrideId()); |                         const auto* id = b.Id(global->OverrideId()); | ||||||
|                         ctx.InsertFront(override->attributes, id); |                         ctx.InsertFront(override->attributes, id); | ||||||
|                     } |                     } | ||||||
|                     ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); |                     b.AST().AddGlobalVariable(ctx.Clone(override)); | ||||||
|                 } |                 } | ||||||
|             }, |             }, | ||||||
|             [&](const ast::Var* var) { |             [&](const ast::Var* var) { | ||||||
|                 if (referenced_vars.count(var)) { |                 if (referenced_vars.count(var)) { | ||||||
|                     ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); |                     b.AST().AddGlobalVariable(ctx.Clone(var)); | ||||||
|                 } |                 } | ||||||
|             }, |             }, | ||||||
|             [&](const ast::Const* c) { |             [&](const ast::Const* c) { | ||||||
|                 // Always keep 'const' declarations, as these can be used by attributes and array
 |                 // Always keep 'const' declarations, as these can be used by attributes and array
 | ||||||
|                 // sizes, which are not tracked as transitively used by functions. They also don't
 |                 // sizes, which are not tracked as transitively used by functions. They also don't
 | ||||||
|                 // typically get emitted by the backend unless they're actually used.
 |                 // typically get emitted by the backend unless they're actually used.
 | ||||||
|                 ctx.dst->AST().AddGlobalVariable(ctx.Clone(c)); |                 b.AST().AddGlobalVariable(ctx.Clone(c)); | ||||||
|             }, |             }, | ||||||
|             [&](const ast::Function* func) { |             [&](const ast::Function* func) { | ||||||
|                 if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { |                 if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { | ||||||
|                     ctx.dst->AST().AddFunction(ctx.Clone(func)); |                     b.AST().AddFunction(ctx.Clone(func)); | ||||||
|                 } |                 } | ||||||
|             }, |             }, | ||||||
|             [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); }, |             [&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); }, | ||||||
|             [&](Default) { |             [&](Default) { | ||||||
|                 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) |                 TINT_UNREACHABLE(Transform, b.Diagnostics()) | ||||||
|                     << "unhandled global declaration: " << decl->TypeInfo().name; |                     << "unhandled global declaration: " << decl->TypeInfo().name; | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Clone the entry point.
 |     // Clone the entry point.
 | ||||||
|     ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); |     b.AST().AddFunction(ctx.Clone(entry_point)); | ||||||
|  | 
 | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {} | SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {} | ||||||
|  | |||||||
| @ -53,14 +53,10 @@ class SingleEntryPoint final : public Castable<SingleEntryPoint, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~SingleEntryPoint() override; |     ~SingleEntryPoint() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -37,7 +37,7 @@ namespace tint::transform { | |||||||
| 
 | 
 | ||||||
| using namespace tint::number_suffixes;  // NOLINT
 | using namespace tint::number_suffixes;  // NOLINT
 | ||||||
| 
 | 
 | ||||||
| /// Private implementation of transform
 | /// PIMPL state for the transform
 | ||||||
| struct SpirvAtomic::State { | struct SpirvAtomic::State { | ||||||
|   private: |   private: | ||||||
|     /// A struct that has been forked because a subset of members were made atomic.
 |     /// A struct that has been forked because a subset of members were made atomic.
 | ||||||
| @ -46,19 +46,24 @@ struct SpirvAtomic::State { | |||||||
|         std::unordered_set<size_t> atomic_members; |         std::unordered_set<size_t> atomic_members; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     CloneContext& ctx; |     /// The source program
 | ||||||
|     ProgramBuilder& b = *ctx.dst; |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     /// The clone context
 | ||||||
|  |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|     std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs; |     std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs; | ||||||
|     std::unordered_set<const sem::Variable*> atomic_variables; |     std::unordered_set<const sem::Variable*> atomic_variables; | ||||||
|     utils::UniqueVector<const sem::Expression*, 8> atomic_expressions; |     utils::UniqueVector<const sem::Expression*, 8> atomic_expressions; | ||||||
| 
 | 
 | ||||||
|   public: |   public: | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param c the clone context
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& c) : ctx(c) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|         // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
 |         // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
 | ||||||
|         // for atomic builtin calls.
 |         // for atomic builtin calls.
 | ||||||
|         for (auto* fn : ctx.src->AST().Functions()) { |         for (auto* fn : ctx.src->AST().Functions()) { | ||||||
| @ -102,6 +107,10 @@ struct SpirvAtomic::State { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         if (atomic_expressions.IsEmpty()) { | ||||||
|  |             return SkipTransform; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         // Transform all variables and structure members that were used in atomic operations as
 |         // Transform all variables and structure members that were used in atomic operations as
 | ||||||
|         // atomic types. This propagates up originating expression chains.
 |         // atomic types. This propagates up originating expression chains.
 | ||||||
|         ProcessAtomicExpressions(); |         ProcessAtomicExpressions(); | ||||||
| @ -143,6 +152,7 @@ struct SpirvAtomic::State { | |||||||
|         ReplaceLoadsAndStores(); |         ReplaceLoadsAndStores(); | ||||||
| 
 | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
| @ -297,17 +307,8 @@ const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const { | |||||||
|                                                           ctx->dst->AllocateNodeID(), builtin); |                                                           ctx->dst->AllocateNodeID(), builtin); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     for (auto* fn : program->AST().Functions()) { |     return State{src}.Run(); | ||||||
|         if (ast::HasAttribute<Stub>(fn->attributes)) { |  | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     State{ctx}.Run(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -63,21 +63,13 @@ class SpirvAtomic final : public Castable<SpirvAtomic, Transform> { | |||||||
|         const sem::BuiltinType builtin; |         const sem::BuiltinType builtin; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   protected: |   private: | ||||||
|     struct State; |     struct State; | ||||||
| 
 |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -77,14 +77,20 @@ struct Hasher<DynamicIndex> { | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the Std140 transform
 | /// PIMPL state for the transform
 | ||||||
| struct Std140::State { | struct Std140::State { | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param c the CloneContext
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& c) : ctx(c) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|  |         if (!ShouldRun()) { | ||||||
|  |             // Transform is not required
 | ||||||
|  |             return SkipTransform; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         // Begin by creating forked types for any type that is used as a uniform buffer, that
 |         // Begin by creating forked types for any type that is used as a uniform buffer, that
 | ||||||
|         // either directly or transitively contains a matrix that needs splitting for std140 layout.
 |         // either directly or transitively contains a matrix that needs splitting for std140 layout.
 | ||||||
|         ForkTypes(); |         ForkTypes(); | ||||||
| @ -116,11 +122,11 @@ struct Std140::State { | |||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// @returns true if this transform should be run for the given program
 |     /// @returns true if this transform should be run for the given program
 | ||||||
|     /// @param program the program to inspect
 |     bool ShouldRun() const { | ||||||
|     static bool ShouldRun(const Program* program) { |  | ||||||
|         // Returns true if the type needs to be forked for std140 usage.
 |         // Returns true if the type needs to be forked for std140 usage.
 | ||||||
|         auto needs_fork = [&](const sem::Type* ty) { |         auto needs_fork = [&](const sem::Type* ty) { | ||||||
|             while (auto* arr = ty->As<sem::Array>()) { |             while (auto* arr = ty->As<sem::Array>()) { | ||||||
| @ -135,7 +141,7 @@ struct Std140::State { | |||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Scan structures for members that need forking
 |         // Scan structures for members that need forking
 | ||||||
|         for (auto* ty : program->Types()) { |         for (auto* ty : src->Types()) { | ||||||
|             if (auto* str = ty->As<sem::Struct>()) { |             if (auto* str = ty->As<sem::Struct>()) { | ||||||
|                 if (str->UsedAs(ast::AddressSpace::kUniform)) { |                 if (str->UsedAs(ast::AddressSpace::kUniform)) { | ||||||
|                     for (auto* member : str->Members()) { |                     for (auto* member : str->Members()) { | ||||||
| @ -148,8 +154,8 @@ struct Std140::State { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Scan uniform variables that have types that need forking
 |         // Scan uniform variables that have types that need forking
 | ||||||
|         for (auto* decl : program->AST().GlobalVariables()) { |         for (auto* decl : src->AST().GlobalVariables()) { | ||||||
|             auto* global = program->Sem().Get(decl); |             auto* global = src->Sem().Get(decl); | ||||||
|             if (global->AddressSpace() == ast::AddressSpace::kUniform) { |             if (global->AddressSpace() == ast::AddressSpace::kUniform) { | ||||||
|                 if (needs_fork(global->Type()->UnwrapRef())) { |                 if (needs_fork(global->Type()->UnwrapRef())) { | ||||||
|                     return true; |                     return true; | ||||||
| @ -197,14 +203,16 @@ struct Std140::State { | |||||||
|         } |         } | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|     /// Alias to the semantic info in ctx.src
 |     /// Alias to the semantic info in src
 | ||||||
|     const sem::Info& sem = ctx.src->Sem(); |     const sem::Info& sem = src->Sem(); | ||||||
|     /// Alias to the symbols in ctx.src
 |     /// Alias to the symbols in src
 | ||||||
|     const SymbolTable& sym = ctx.src->Symbols(); |     const SymbolTable& sym = src->Symbols(); | ||||||
|     /// Alias to the ctx.dst program builder
 |  | ||||||
|     ProgramBuilder& b = *ctx.dst; |  | ||||||
| 
 | 
 | ||||||
|     /// Map of load function signature, to the generated function
 |     /// Map of load function signature, to the generated function
 | ||||||
|     utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns; |     utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns; | ||||||
| @ -218,7 +226,7 @@ struct Std140::State { | |||||||
|     // Map of original structure to 'std140' forked structure
 |     // Map of original structure to 'std140' forked structure
 | ||||||
|     utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs; |     utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs; | ||||||
| 
 | 
 | ||||||
|     // Map of structure member in ctx.src of a matrix type, to list of decomposed column
 |     // Map of structure member in src of a matrix type, to list of decomposed column
 | ||||||
|     // members in ctx.dst.
 |     // members in ctx.dst.
 | ||||||
|     utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8> |     utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8> | ||||||
|         std140_mat_members; |         std140_mat_members; | ||||||
| @ -232,7 +240,7 @@ struct Std140::State { | |||||||
|         utils::Vector<Symbol, 4> columns; |         utils::Vector<Symbol, 4> columns; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     // Map of matrix type in ctx.src, to decomposed column structure in ctx.dst.
 |     // Map of matrix type in src, to decomposed column structure in ctx.dst.
 | ||||||
|     utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats; |     utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats; | ||||||
| 
 | 
 | ||||||
|     /// AccessChain describes a chain of access expressions to uniform buffer variable.
 |     /// AccessChain describes a chain of access expressions to uniform buffer variable.
 | ||||||
| @ -266,7 +274,7 @@ struct Std140::State { | |||||||
|     /// map (via Std140Type()).
 |     /// map (via Std140Type()).
 | ||||||
|     void ForkTypes() { |     void ForkTypes() { | ||||||
|         // For each module scope declaration...
 |         // For each module scope declaration...
 | ||||||
|         for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) { |         for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) { | ||||||
|             // Check to see if this is a structure used by a uniform buffer...
 |             // Check to see if this is a structure used by a uniform buffer...
 | ||||||
|             auto* str = sem.Get<sem::Struct>(global); |             auto* str = sem.Get<sem::Struct>(global); | ||||||
|             if (str && str->UsedAs(ast::AddressSpace::kUniform)) { |             if (str && str->UsedAs(ast::AddressSpace::kUniform)) { | ||||||
| @ -317,7 +325,7 @@ struct Std140::State { | |||||||
|                 if (fork_std140) { |                 if (fork_std140) { | ||||||
|                     // Clone any members that have not already been cloned.
 |                     // Clone any members that have not already been cloned.
 | ||||||
|                     for (auto& member : members) { |                     for (auto& member : members) { | ||||||
|                         if (member->program_id == ctx.src->ID()) { |                         if (member->program_id == src->ID()) { | ||||||
|                             member = ctx.Clone(member); |                             member = ctx.Clone(member); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
| @ -326,7 +334,7 @@ struct Std140::State { | |||||||
|                     auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140"); |                     auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140"); | ||||||
|                     auto* std140 = b.create<ast::Struct>(name, std::move(members), |                     auto* std140 = b.create<ast::Struct>(name, std::move(members), | ||||||
|                                                          ctx.Clone(str->Declaration()->attributes)); |                                                          ctx.Clone(str->Declaration()->attributes)); | ||||||
|                     ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140); |                     ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140); | ||||||
|                     std140_structs.Add(str, name); |                     std140_structs.Add(str, name); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @ -337,14 +345,13 @@ struct Std140::State { | |||||||
|     /// type that has been forked for std140-layout.
 |     /// type that has been forked for std140-layout.
 | ||||||
|     /// Populates the #std140_uniforms set.
 |     /// Populates the #std140_uniforms set.
 | ||||||
|     void ReplaceUniformVarTypes() { |     void ReplaceUniformVarTypes() { | ||||||
|         for (auto* global : ctx.src->AST().GlobalVariables()) { |         for (auto* global : src->AST().GlobalVariables()) { | ||||||
|             if (auto* var = global->As<ast::Var>()) { |             if (auto* var = global->As<ast::Var>()) { | ||||||
|                 if (var->declared_address_space == ast::AddressSpace::kUniform) { |                 if (var->declared_address_space == ast::AddressSpace::kUniform) { | ||||||
|                     auto* v = sem.Get(var); |                     auto* v = sem.Get(var); | ||||||
|                     if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) { |                     if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) { | ||||||
|                         ctx.Replace(global->type, std140_ty); |                         ctx.Replace(global->type, std140_ty); | ||||||
|                         std140_uniforms.Add(v); |                         std140_uniforms.Add(v); | ||||||
|                         continue; |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @ -404,7 +411,7 @@ struct Std140::State { | |||||||
|                     auto std140_mat = std140_mats.GetOrCreate(mat, [&] { |                     auto std140_mat = std140_mats.GetOrCreate(mat, [&] { | ||||||
|                         auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" + |                         auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" + | ||||||
|                                                     std::to_string(mat->rows()) + "_" + |                                                     std::to_string(mat->rows()) + "_" + | ||||||
|                                                     ctx.src->FriendlyName(mat->type())); |                                                     src->FriendlyName(mat->type())); | ||||||
|                         auto members = |                         auto members = | ||||||
|                             DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size()); |                             DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size()); | ||||||
|                         b.Structure(name, members); |                         b.Structure(name, members); | ||||||
| @ -421,7 +428,7 @@ struct Std140::State { | |||||||
|                 if (auto* std140 = Std140Type(arr->ElemType())) { |                 if (auto* std140 = Std140Type(arr->ElemType())) { | ||||||
|                     utils::Vector<const ast::Attribute*, 1> attrs; |                     utils::Vector<const ast::Attribute*, 1> attrs; | ||||||
|                     if (!arr->IsStrideImplicit()) { |                     if (!arr->IsStrideImplicit()) { | ||||||
|                         attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride())); |                         attrs.Push(b.create<ast::StrideAttribute>(arr->Stride())); | ||||||
|                     } |                     } | ||||||
|                     auto count = arr->ConstantCount(); |                     auto count = arr->ConstantCount(); | ||||||
|                     if (!count) { |                     if (!count) { | ||||||
| @ -429,7 +436,7 @@ struct Std140::State { | |||||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 |                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||||
|                         //   this method only handles types transitively used as uniform buffers.
 |                         //   this method only handles types transitively used as uniform buffers.
 | ||||||
|                         // * Runtime-sized arrays cannot be used in uniform buffers.
 |                         // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "unexpected non-constant array count"; |                             << "unexpected non-constant array count"; | ||||||
|                         count = 1; |                         count = 1; | ||||||
|                     } |                     } | ||||||
| @ -440,7 +447,7 @@ struct Std140::State { | |||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// @param mat the matrix to decompose (in ctx.src)
 |     /// @param mat the matrix to decompose (in src)
 | ||||||
|     /// @param name_prefix the name prefix to apply to each of the returned column vector members.
 |     /// @param name_prefix the name prefix to apply to each of the returned column vector members.
 | ||||||
|     /// @param align the alignment in bytes of the matrix.
 |     /// @param align the alignment in bytes of the matrix.
 | ||||||
|     /// @param size the size in bytes of the matrix.
 |     /// @param size the size in bytes of the matrix.
 | ||||||
| @ -473,7 +480,7 @@ struct Std140::State { | |||||||
|             // Build the member
 |             // Build the member
 | ||||||
|             const auto col_name = name_prefix + std::to_string(i); |             const auto col_name = name_prefix + std::to_string(i); | ||||||
|             const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType()); |             const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType()); | ||||||
|             const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes)); |             const auto* col_member = b.Member(col_name, col_ty, std::move(attributes)); | ||||||
|             // Record the member for std140_mat_members
 |             // Record the member for std140_mat_members
 | ||||||
|             out.Push(col_member); |             out.Push(col_member); | ||||||
|         } |         } | ||||||
| @ -618,7 +625,7 @@ struct Std140::State { | |||||||
| 
 | 
 | ||||||
|     /// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
 |     /// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
 | ||||||
|     ///          being converted.
 |     ///          being converted.
 | ||||||
|     const std::string ConvertSuffix(const sem::Type* ty) const { |     const std::string ConvertSuffix(const sem::Type* ty) { | ||||||
|         return Switch( |         return Switch( | ||||||
|             ty,  //
 |             ty,  //
 | ||||||
|             [&](const sem::Struct* str) { return sym.NameFor(str->Name()); }, |             [&](const sem::Struct* str) { return sym.NameFor(str->Name()); }, | ||||||
| @ -629,8 +636,7 @@ struct Std140::State { | |||||||
|                     // * Override-expression counts can only be applied to workgroup arrays, and
 |                     // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||||
|                     //   this method only handles types transitively used as uniform buffers.
 |                     //   this method only handles types transitively used as uniform buffers.
 | ||||||
|                     // * Runtime-sized arrays cannot be used in uniform buffers.
 |                     // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||||
|                     TINT_ICE(Transform, ctx.dst->Diagnostics()) |                     TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; | ||||||
|                         << "unexpected non-constant array count"; |  | ||||||
|                     count = 1; |                     count = 1; | ||||||
|                 } |                 } | ||||||
|                 return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType()); |                 return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType()); | ||||||
| @ -642,7 +648,7 @@ struct Std140::State { | |||||||
|             [&](const sem::F32*) { return "f32"; }, |             [&](const sem::F32*) { return "f32"; }, | ||||||
|             [&](Default) { |             [&](Default) { | ||||||
|                 TINT_ICE(Transform, b.Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "unhandled type for conversion name: " << ctx.src->FriendlyName(ty); |                     << "unhandled type for conversion name: " << src->FriendlyName(ty); | ||||||
|                 return ""; |                 return ""; | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @ -718,8 +724,7 @@ struct Std140::State { | |||||||
|                         stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args)))); |                         stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args)))); | ||||||
|                     } else { |                     } else { | ||||||
|                         TINT_ICE(Transform, b.Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "failed to find std140 matrix info for: " |                             << "failed to find std140 matrix info for: " << src->FriendlyName(ty); | ||||||
|                             << ctx.src->FriendlyName(ty); |  | ||||||
|                     } |                     } | ||||||
|                 },  //
 |                 },  //
 | ||||||
|                 [&](const sem::Array* arr) { |                 [&](const sem::Array* arr) { | ||||||
| @ -736,7 +741,7 @@ struct Std140::State { | |||||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 |                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||||
|                         //   this method only handles types transitively used as uniform buffers.
 |                         //   this method only handles types transitively used as uniform buffers.
 | ||||||
|                         // * Runtime-sized arrays cannot be used in uniform buffers.
 |                         // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) |                         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                             << "unexpected non-constant array count"; |                             << "unexpected non-constant array count"; | ||||||
|                         count = 1; |                         count = 1; | ||||||
|                     } |                     } | ||||||
| @ -749,7 +754,7 @@ struct Std140::State { | |||||||
|                 }, |                 }, | ||||||
|                 [&](Default) { |                 [&](Default) { | ||||||
|                     TINT_ICE(Transform, b.Diagnostics()) |                     TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                         << "unhandled type for conversion: " << ctx.src->FriendlyName(ty); |                         << "unhandled type for conversion: " << src->FriendlyName(ty); | ||||||
|                 }); |                 }); | ||||||
| 
 | 
 | ||||||
|             // Generate the function
 |             // Generate the function
 | ||||||
| @ -1063,7 +1068,7 @@ struct Std140::State { | |||||||
| 
 | 
 | ||||||
|         if (std::get_if<UniformVariable>(&access)) { |         if (std::get_if<UniformVariable>(&access)) { | ||||||
|             const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol)); |             const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol)); | ||||||
|             const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol); |             const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol); | ||||||
|             ty = chain.var->Type()->UnwrapRef(); |             ty = chain.var->Type()->UnwrapRef(); | ||||||
|             return {expr, ty, name}; |             return {expr, ty, name}; | ||||||
|         } |         } | ||||||
| @ -1090,7 +1095,7 @@ struct Std140::State { | |||||||
|                 },  //
 |                 },  //
 | ||||||
|                 [&](Default) -> ExprTypeName { |                 [&](Default) -> ExprTypeName { | ||||||
|                     TINT_ICE(Transform, b.Diagnostics()) |                     TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                         << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); |                         << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||||
|                     return {}; |                     return {}; | ||||||
|                 }); |                 }); | ||||||
|         } |         } | ||||||
| @ -1104,14 +1109,14 @@ struct Std140::State { | |||||||
|                     for (auto el : *swizzle) { |                     for (auto el : *swizzle) { | ||||||
|                         rhs += xyzw[el]; |                         rhs += xyzw[el]; | ||||||
|                     } |                     } | ||||||
|                     auto swizzle_ty = ctx.src->Types().Find<sem::Vector>( |                     auto swizzle_ty = src->Types().Find<sem::Vector>( | ||||||
|                         vec->type(), static_cast<uint32_t>(swizzle->Length())); |                         vec->type(), static_cast<uint32_t>(swizzle->Length())); | ||||||
|                     auto* expr = b.MemberAccessor(lhs, rhs); |                     auto* expr = b.MemberAccessor(lhs, rhs); | ||||||
|                     return {expr, swizzle_ty, rhs}; |                     return {expr, swizzle_ty, rhs}; | ||||||
|                 },  //
 |                 },  //
 | ||||||
|                 [&](Default) -> ExprTypeName { |                 [&](Default) -> ExprTypeName { | ||||||
|                     TINT_ICE(Transform, b.Diagnostics()) |                     TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                         << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); |                         << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||||
|                     return {}; |                     return {}; | ||||||
|                 }); |                 }); | ||||||
|         } |         } | ||||||
| @ -1140,7 +1145,7 @@ struct Std140::State { | |||||||
|             },  //
 |             },  //
 | ||||||
|             [&](Default) -> ExprTypeName { |             [&](Default) -> ExprTypeName { | ||||||
|                 TINT_ICE(Transform, b.Diagnostics()) |                 TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                     << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); |                     << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||||
|                 return {}; |                 return {}; | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @ -1150,12 +1155,8 @@ Std140::Std140() = default; | |||||||
| 
 | 
 | ||||||
| Std140::~Std140() = default; | Std140::~Std140() = default; | ||||||
| 
 | 
 | ||||||
| bool Std140::ShouldRun(const Program* program, const DataMap&) const { | Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     return State::ShouldRun(program); |     return State(src).Run(); | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     State(ctx).Run(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -34,21 +34,13 @@ class Std140 final : public Castable<Std140, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Std140() override; |     ~Std140() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     struct State; |     struct State; | ||||||
| 
 |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ | |||||||
| #include "src/tint/transform/substitute_override.h" | #include "src/tint/transform/substitute_override.h" | ||||||
| 
 | 
 | ||||||
| #include <functional> | #include <functional> | ||||||
|  | #include <utility> | ||||||
| 
 | 
 | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| #include "src/tint/sem/builtin.h" | #include "src/tint/sem/builtin.h" | ||||||
| @ -25,12 +26,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride); | |||||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config); | TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config); | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
|  | namespace { | ||||||
| 
 | 
 | ||||||
| SubstituteOverride::SubstituteOverride() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| SubstituteOverride::~SubstituteOverride() = default; |  | ||||||
| 
 |  | ||||||
| bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->AST().GlobalVariables()) { |     for (auto* node : program->AST().GlobalVariables()) { | ||||||
|         if (node->Is<ast::Override>()) { |         if (node->Is<ast::Override>()) { | ||||||
|             return true; |             return true; | ||||||
| @ -39,18 +37,32 @@ bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const { | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | SubstituteOverride::SubstituteOverride() = default; | ||||||
|  | 
 | ||||||
|  | SubstituteOverride::~SubstituteOverride() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult SubstituteOverride::Apply(const Program* src, | ||||||
|  |                                                  const DataMap& config, | ||||||
|  |                                                  DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     const auto* data = config.Get<Config>(); |     const auto* data = config.Get<Config>(); | ||||||
|     if (!data) { |     if (!data) { | ||||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, |         b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data"); | ||||||
|                                          "Missing override substitution data"); |         return Program(std::move(b)); | ||||||
|         return; |     } | ||||||
|  | 
 | ||||||
|  |     if (!ShouldRun(ctx.src)) { | ||||||
|  |         return SkipTransform; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* { |     ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* { | ||||||
|         auto* sem = ctx.src->Sem().Get(w); |         auto* sem = ctx.src->Sem().Get(w); | ||||||
| 
 | 
 | ||||||
|         auto src = ctx.Clone(w->source); |         auto source = ctx.Clone(w->source); | ||||||
|         auto sym = ctx.Clone(w->symbol); |         auto sym = ctx.Clone(w->symbol); | ||||||
|         auto* ty = ctx.Clone(w->type); |         auto* ty = ctx.Clone(w->type); | ||||||
| 
 | 
 | ||||||
| @ -58,30 +70,30 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | |||||||
|         auto iter = data->map.find(sem->OverrideId()); |         auto iter = data->map.find(sem->OverrideId()); | ||||||
|         if (iter == data->map.end()) { |         if (iter == data->map.end()) { | ||||||
|             if (!w->initializer) { |             if (!w->initializer) { | ||||||
|                 ctx.dst->Diagnostics().add_error( |                 b.Diagnostics().add_error( | ||||||
|                     diag::System::Transform, |                     diag::System::Transform, | ||||||
|                     "Initializer not provided for override, and override not overridden."); |                     "Initializer not provided for override, and override not overridden."); | ||||||
|                 return nullptr; |                 return nullptr; | ||||||
|             } |             } | ||||||
|             return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer)); |             return b.Const(source, sym, ty, ctx.Clone(w->initializer)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         auto value = iter->second; |         auto value = iter->second; | ||||||
|         auto* ctor = Switch( |         auto* ctor = Switch( | ||||||
|             sem->Type(), |             sem->Type(), | ||||||
|             [&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to<double>()(value, 0.0)); }, |             [&](const sem::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); }, | ||||||
|             [&](const sem::I32*) { return ctx.dst->Expr(i32(value)); }, |             [&](const sem::I32*) { return b.Expr(i32(value)); }, | ||||||
|             [&](const sem::U32*) { return ctx.dst->Expr(u32(value)); }, |             [&](const sem::U32*) { return b.Expr(u32(value)); }, | ||||||
|             [&](const sem::F32*) { return ctx.dst->Expr(f32(value)); }, |             [&](const sem::F32*) { return b.Expr(f32(value)); }, | ||||||
|             [&](const sem::F16*) { return ctx.dst->Expr(f16(value)); }); |             [&](const sem::F16*) { return b.Expr(f16(value)); }); | ||||||
| 
 | 
 | ||||||
|         if (!ctor) { |         if (!ctor) { | ||||||
|             ctx.dst->Diagnostics().add_error(diag::System::Transform, |             b.Diagnostics().add_error(diag::System::Transform, | ||||||
|                                              "Failed to create override-expression"); |                                       "Failed to create override-expression"); | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return ctx.dst->Const(src, sym, ty, ctor); |         return b.Const(source, sym, ty, ctor); | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     // Ensure that objects that are indexed with an override-expression are materialized.
 |     // Ensure that objects that are indexed with an override-expression are materialized.
 | ||||||
| @ -89,11 +101,10 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | |||||||
|     // resulting type of the index may change. See: crbug.com/tint/1697.
 |     // resulting type of the index may change. See: crbug.com/tint/1697.
 | ||||||
|     ctx.ReplaceAll( |     ctx.ReplaceAll( | ||||||
|         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { |         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { | ||||||
|             if (auto* sem = ctx.src->Sem().Get(expr)) { |             if (auto* sem = src->Sem().Get(expr)) { | ||||||
|                 if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) { |                 if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) { | ||||||
|                     if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() && |                     if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() && | ||||||
|                         access->Index()->Stage() == sem::EvaluationStage::kOverride) { |                         access->Index()->Stage() == sem::EvaluationStage::kOverride) { | ||||||
|                         auto& b = *ctx.dst; |  | ||||||
|                         auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize), |                         auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize), | ||||||
|                                            ctx.Clone(expr->object)); |                                            ctx.Clone(expr->object)); | ||||||
|                         return b.IndexAccessor(obj, ctx.Clone(expr->index)); |                         return b.IndexAccessor(obj, ctx.Clone(expr->index)); | ||||||
| @ -104,6 +115,7 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | |||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| SubstituteOverride::Config::Config() = default; | SubstituteOverride::Config::Config() = default; | ||||||
|  | |||||||
| @ -75,19 +75,10 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform> | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~SubstituteOverride() override; |     ~SubstituteOverride() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -122,7 +122,18 @@ class TransformTestBase : public BASE { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         const Transform& t = TRANSFORM(); |         const Transform& t = TRANSFORM(); | ||||||
|         return t.ShouldRun(&program, data); | 
 | ||||||
|  |         DataMap outputs; | ||||||
|  |         auto result = t.Apply(&program, data, outputs); | ||||||
|  |         if (!result) { | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |         if (!result->IsValid()) { | ||||||
|  |             ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: " | ||||||
|  |                           << result->Diagnostics().str(); | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |         return result.has_value(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// @param in the input WGSL source
 |     /// @param in the input WGSL source
 | ||||||
|  | |||||||
| @ -46,24 +46,19 @@ Output::Output(Program&& p) : program(std::move(p)) {} | |||||||
| Transform::Transform() = default; | Transform::Transform() = default; | ||||||
| Transform::~Transform() = default; | Transform::~Transform() = default; | ||||||
| 
 | 
 | ||||||
| Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const { | Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const { | ||||||
|     ProgramBuilder builder; |  | ||||||
|     CloneContext ctx(&builder, program); |  | ||||||
|     Output output; |     Output output; | ||||||
|     Run(ctx, data, output.data); |     if (auto program = Apply(src, data, output.data)) { | ||||||
|     output.program = Program(std::move(builder)); |         output.program = std::move(program.value()); | ||||||
|  |     } else { | ||||||
|  |         ProgramBuilder b; | ||||||
|  |         CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  |         ctx.Clone(); | ||||||
|  |         output.program = Program(std::move(b)); | ||||||
|  |     } | ||||||
|     return output; |     return output; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |  | ||||||
|     TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics()) |  | ||||||
|         << "Transform::Run() unimplemented for " << TypeInfo().name; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| bool Transform::ShouldRun(const Program*, const DataMap&) const { |  | ||||||
|     return true; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) { | void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) { | ||||||
|     auto* sem = ctx.src->Sem().Get(stmt); |     auto* sem = ctx.src->Sem().Get(stmt); | ||||||
|     if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) { |     if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) { | ||||||
|  | |||||||
| @ -158,26 +158,30 @@ class Transform : public Castable<Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Transform() override; |     ~Transform() override; | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform on `program`, returning the transformation result.
 |     /// Runs the transform on @p program, returning the transformation result or a clone of
 | ||||||
|  |     /// @p program.
 | ||||||
|     /// @param program the source program to transform
 |     /// @param program the source program to transform
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     /// @param data optional extra transform-specific input data
 | ||||||
|     /// @returns the transformation result
 |     /// @returns the transformation result
 | ||||||
|     virtual Output Run(const Program* program, const DataMap& data = {}) const; |     Output Run(const Program* program, const DataMap& data = {}) const; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// The return value of Apply().
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     /// If SkipTransform (std::nullopt), then the transform is not needed to be run.
 | ||||||
|     /// @returns true if this transform should be run for the given program
 |     using ApplyResult = std::optional<Program>; | ||||||
|     virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const; |  | ||||||
| 
 | 
 | ||||||
|   protected: |     /// Value returned from Apply() to indicate that the transform does not need to be run
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     static inline constexpr std::nullopt_t SkipTransform = std::nullopt; | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | 
 | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |     /// Runs the transform on `program`, return.
 | ||||||
|     /// ProgramBuilder
 |     /// @param program the input program
 | ||||||
|     /// @param inputs optional extra transform-specific input data
 |     /// @param inputs optional extra transform-specific input data
 | ||||||
|     /// @param outputs optional extra transform-specific output data
 |     /// @param outputs optional extra transform-specific output data
 | ||||||
|     virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const; |     /// @returns a transformed program, or std::nullopt if the transform didn't need to run.
 | ||||||
|  |     virtual ApplyResult Apply(const Program* program, | ||||||
|  |                               const DataMap& inputs, | ||||||
|  |                               DataMap& outputs) const = 0; | ||||||
| 
 | 
 | ||||||
|  |   protected: | ||||||
|     /// Removes the statement `stmt` from the transformed program.
 |     /// Removes the statement `stmt` from the transformed program.
 | ||||||
|     /// RemoveStatement handles edge cases, like statements in the initializer and
 |     /// RemoveStatement handles edge cases, like statements in the initializer and
 | ||||||
|     /// continuing of for-loops.
 |     /// continuing of for-loops.
 | ||||||
|  | |||||||
| @ -23,7 +23,9 @@ namespace { | |||||||
| 
 | 
 | ||||||
| // Inherit from Transform so we have access to protected methods
 | // Inherit from Transform so we have access to protected methods
 | ||||||
| struct CreateASTTypeForTest : public testing::Test, public Transform { | struct CreateASTTypeForTest : public testing::Test, public Transform { | ||||||
|     Output Run(const Program*, const DataMap&) const override { return {}; } |     ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) { |     const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) { | ||||||
|         ProgramBuilder sem_type_builder; |         ProgramBuilder sem_type_builder; | ||||||
|  | |||||||
| @ -28,27 +28,32 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// The PIMPL state for the Unshadow transform
 | /// PIMPL state for the transform
 | ||||||
| struct Unshadow::State { | struct Unshadow::State { | ||||||
|  |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|     /// The clone context
 |     /// The clone context
 | ||||||
|     CloneContext& ctx; |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
| 
 | 
 | ||||||
|     /// Constructor
 |     /// Constructor
 | ||||||
|     /// @param context the clone context
 |     /// @param program the source program
 | ||||||
|     explicit State(CloneContext& context) : ctx(context) {} |     explicit State(const Program* program) : src(program) {} | ||||||
| 
 | 
 | ||||||
|     /// Performs the transformation
 |     /// Runs the transform
 | ||||||
|     void Run() { |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|         auto& sem = ctx.src->Sem(); |     Transform::ApplyResult Run() { | ||||||
|  |         auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|         // Maps a variable to its new name.
 |         // Maps a variable to its new name.
 | ||||||
|         std::unordered_map<const sem::Variable*, Symbol> renamed_to; |         utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to; | ||||||
| 
 | 
 | ||||||
|         auto rename = [&](const sem::Variable* v) -> const ast::Variable* { |         auto rename = [&](const sem::Variable* v) -> const ast::Variable* { | ||||||
|             auto* decl = v->Declaration(); |             auto* decl = v->Declaration(); | ||||||
|             auto name = ctx.src->Symbols().NameFor(decl->symbol); |             auto name = src->Symbols().NameFor(decl->symbol); | ||||||
|             auto symbol = ctx.dst->Symbols().New(name); |             auto symbol = b.Symbols().New(name); | ||||||
|             renamed_to.emplace(v, symbol); |             renamed_to.Add(v, symbol); | ||||||
| 
 | 
 | ||||||
|             auto source = ctx.Clone(decl->source); |             auto source = ctx.Clone(decl->source); | ||||||
|             auto* type = ctx.Clone(decl->type); |             auto* type = ctx.Clone(decl->type); | ||||||
| @ -57,20 +62,20 @@ struct Unshadow::State { | |||||||
|             return Switch( |             return Switch( | ||||||
|                 decl,  //
 |                 decl,  //
 | ||||||
|                 [&](const ast::Var* var) { |                 [&](const ast::Var* var) { | ||||||
|                     return ctx.dst->Var(source, symbol, type, var->declared_address_space, |                     return b.Var(source, symbol, type, var->declared_address_space, | ||||||
|                                         var->declared_access, initializer, attributes); |                                  var->declared_access, initializer, attributes); | ||||||
|                 }, |                 }, | ||||||
|                 [&](const ast::Let*) { |                 [&](const ast::Let*) { | ||||||
|                     return ctx.dst->Let(source, symbol, type, initializer, attributes); |                     return b.Let(source, symbol, type, initializer, attributes); | ||||||
|                 }, |                 }, | ||||||
|                 [&](const ast::Const*) { |                 [&](const ast::Const*) { | ||||||
|                     return ctx.dst->Const(source, symbol, type, initializer, attributes); |                     return b.Const(source, symbol, type, initializer, attributes); | ||||||
|                 }, |                 }, | ||||||
|                 [&](const ast::Parameter*) { |                 [&](const ast::Parameter*) {  //
 | ||||||
|                     return ctx.dst->Param(source, symbol, type, attributes); |                     return b.Param(source, symbol, type, attributes); | ||||||
|                 }, |                 }, | ||||||
|                 [&](Default) { |                 [&](Default) { | ||||||
|                     TINT_ICE(Transform, ctx.dst->Diagnostics()) |                     TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                         << "unexpected variable type: " << decl->TypeInfo().name; |                         << "unexpected variable type: " << decl->TypeInfo().name; | ||||||
|                     return nullptr; |                     return nullptr; | ||||||
|                 }); |                 }); | ||||||
| @ -92,14 +97,15 @@ struct Unshadow::State { | |||||||
|         ctx.ReplaceAll( |         ctx.ReplaceAll( | ||||||
|             [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { |             [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { | ||||||
|                 if (auto* user = sem.Get<sem::VariableUser>(ident)) { |                 if (auto* user = sem.Get<sem::VariableUser>(ident)) { | ||||||
|                     auto it = renamed_to.find(user->Variable()); |                     if (auto* renamed = renamed_to.Find(user->Variable())) { | ||||||
|                     if (it != renamed_to.end()) { |                         return b.Expr(*renamed); | ||||||
|                         return ctx.dst->Expr(it->second); |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 return nullptr; |                 return nullptr; | ||||||
|             }); |             }); | ||||||
|  | 
 | ||||||
|         ctx.Clone(); |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| @ -107,8 +113,8 @@ Unshadow::Unshadow() = default; | |||||||
| 
 | 
 | ||||||
| Unshadow::~Unshadow() = default; | Unshadow::~Unshadow() = default; | ||||||
| 
 | 
 | ||||||
| void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|     State(ctx).Run(); |     return State(src).Run(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -29,16 +29,13 @@ class Unshadow final : public Castable<Unshadow, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~Unshadow() override; |     ~Unshadow() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     struct State; |     ApplyResult Apply(const Program* program, | ||||||
|  |                       const DataMap& inputs, | ||||||
|  |                       DataMap& outputs) const override; | ||||||
| 
 | 
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |   private: | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |     struct State; | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -35,7 +35,51 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions); | |||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| class State { | bool ShouldRun(const Program* program) { | ||||||
|  |     auto& sem = program->Sem(); | ||||||
|  |     for (auto* f : program->AST().Functions()) { | ||||||
|  |         if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | /// PIMPL state for the transform
 | ||||||
|  | struct UnwindDiscardFunctions::State { | ||||||
|  |     /// Constructor
 | ||||||
|  |     /// @param ctx_in the context
 | ||||||
|  |     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} | ||||||
|  | 
 | ||||||
|  |     /// Runs the transform
 | ||||||
|  |     void Run() { | ||||||
|  |         ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { | ||||||
|  |             // Iterate block statements and replace them as needed.
 | ||||||
|  |             for (auto* stmt : block->statements) { | ||||||
|  |                 if (auto* new_stmt = Statement(stmt)) { | ||||||
|  |                     ctx.Replace(stmt, new_stmt); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 // Handle for loops, as they are the only other AST node that
 | ||||||
|  |                 // contains statements outside of BlockStatements.
 | ||||||
|  |                 if (auto* fl = stmt->As<ast::ForLoopStatement>()) { | ||||||
|  |                     if (auto* new_stmt = Statement(fl->initializer)) { | ||||||
|  |                         ctx.Replace(fl->initializer, new_stmt); | ||||||
|  |                     } | ||||||
|  |                     if (auto* new_stmt = Statement(fl->continuing)) { | ||||||
|  |                         // NOTE: Should never reach here as we cannot discard in a
 | ||||||
|  |                         // continuing block.
 | ||||||
|  |                         ctx.Replace(fl->continuing, new_stmt); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             return nullptr; | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|   private: |   private: | ||||||
|     CloneContext& ctx; |     CloneContext& ctx; | ||||||
|     ProgramBuilder& b; |     ProgramBuilder& b; | ||||||
| @ -163,7 +207,7 @@ class State { | |||||||
|     // Returns true if `stmt` is a for-loop initializer statement.
 |     // Returns true if `stmt` is a for-loop initializer statement.
 | ||||||
|     bool IsForLoopInitStatement(const ast::Statement* stmt) { |     bool IsForLoopInitStatement(const ast::Statement* stmt) { | ||||||
|         if (auto* sem_stmt = sem.Get(stmt)) { |         if (auto* sem_stmt = sem.Get(stmt)) { | ||||||
|             if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) { |             if (auto* sem_fl = tint::As<sem::ForLoopStatement>(sem_stmt->Parent())) { | ||||||
|                 return sem_fl->Declaration()->initializer == stmt; |                 return sem_fl->Declaration()->initializer == stmt; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @ -305,60 +349,26 @@ class State { | |||||||
|                 return TryInsertAfter(s, sem_expr); |                 return TryInsertAfter(s, sem_expr); | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| 
 |  | ||||||
|   public: |  | ||||||
|     /// Constructor
 |  | ||||||
|     /// @param ctx_in the context
 |  | ||||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} |  | ||||||
| 
 |  | ||||||
|     /// Runs the transform
 |  | ||||||
|     void Run() { |  | ||||||
|         ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { |  | ||||||
|             // Iterate block statements and replace them as needed.
 |  | ||||||
|             for (auto* stmt : block->statements) { |  | ||||||
|                 if (auto* new_stmt = Statement(stmt)) { |  | ||||||
|                     ctx.Replace(stmt, new_stmt); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 // Handle for loops, as they are the only other AST node that
 |  | ||||||
|                 // contains statements outside of BlockStatements.
 |  | ||||||
|                 if (auto* fl = stmt->As<ast::ForLoopStatement>()) { |  | ||||||
|                     if (auto* new_stmt = Statement(fl->initializer)) { |  | ||||||
|                         ctx.Replace(fl->initializer, new_stmt); |  | ||||||
|                     } |  | ||||||
|                     if (auto* new_stmt = Statement(fl->continuing)) { |  | ||||||
|                         // NOTE: Should never reach here as we cannot discard in a
 |  | ||||||
|                         // continuing block.
 |  | ||||||
|                         ctx.Replace(fl->continuing, new_stmt); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             return nullptr; |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         ctx.Clone(); |  | ||||||
|     } |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace
 |  | ||||||
| 
 |  | ||||||
| UnwindDiscardFunctions::UnwindDiscardFunctions() = default; | UnwindDiscardFunctions::UnwindDiscardFunctions() = default; | ||||||
| UnwindDiscardFunctions::~UnwindDiscardFunctions() = default; | UnwindDiscardFunctions::~UnwindDiscardFunctions() = default; | ||||||
| 
 | 
 | ||||||
| void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src, | ||||||
|  |                                                      const DataMap&, | ||||||
|  |                                                      DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     State state(ctx); |     State state(ctx); | ||||||
|     state.Run(); |     state.Run(); | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const { |     ctx.Clone(); | ||||||
|     auto& sem = program->Sem(); |     return Program(std::move(b)); | ||||||
|     for (auto* f : program->AST().Functions()) { |  | ||||||
|         if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { |  | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     return false; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -44,19 +44,13 @@ class UnwindDiscardFunctions final : public Castable<UnwindDiscardFunctions, Tra | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~UnwindDiscardFunctions() override; |     ~UnwindDiscardFunctions() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |   private: | ||||||
|     /// @param data optional extra transform-specific input data
 |     struct State; | ||||||
|     /// @returns true if this transform should be run for the given program
 |  | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -30,7 +30,59 @@ | |||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| /// Private implementation of HoistToDeclBefore transform
 | /// Private implementation of HoistToDeclBefore transform
 | ||||||
| class HoistToDeclBefore::State { | struct HoistToDeclBefore::State { | ||||||
|  |     /// Constructor
 | ||||||
|  |     /// @param ctx_in the clone context
 | ||||||
|  |     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} | ||||||
|  | 
 | ||||||
|  |     /// @copydoc HoistToDeclBefore::Add()
 | ||||||
|  |     bool Add(const sem::Expression* before_expr, | ||||||
|  |              const ast::Expression* expr, | ||||||
|  |              bool as_let, | ||||||
|  |              const char* decl_name) { | ||||||
|  |         auto name = b.Symbols().New(decl_name); | ||||||
|  | 
 | ||||||
|  |         if (as_let) { | ||||||
|  |             auto builder = [this, expr, name] { | ||||||
|  |                 return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr))); | ||||||
|  |             }; | ||||||
|  |             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             auto builder = [this, expr, name] { | ||||||
|  |                 return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); | ||||||
|  |             }; | ||||||
|  |             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Replace the initializer expression with a reference to the let
 | ||||||
|  |         ctx.Replace(expr, b.Expr(name)); | ||||||
|  |         return true; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
 | ||||||
|  |     bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { | ||||||
|  |         if (stmt) { | ||||||
|  |             auto builder = [stmt] { return stmt; }; | ||||||
|  |             return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||||
|  |         } | ||||||
|  |         return InsertBeforeImpl(before_stmt, Decompose{}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
 | ||||||
|  |     bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { | ||||||
|  |         return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /// @copydoc HoistToDeclBefore::Prepare()
 | ||||||
|  |     bool Prepare(const sem::Expression* before_expr) { | ||||||
|  |         return InsertBefore(before_expr->Stmt(), nullptr); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|     CloneContext& ctx; |     CloneContext& ctx; | ||||||
|     ProgramBuilder& b; |     ProgramBuilder& b; | ||||||
| 
 | 
 | ||||||
| @ -215,6 +267,8 @@ class HoistToDeclBefore::State { | |||||||
| 
 | 
 | ||||||
|     template <typename BUILDER> |     template <typename BUILDER> | ||||||
|     bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) { |     bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) { | ||||||
|  |         (void)builder;  // Avoid 'unused parameter' warning due to 'if constexpr'
 | ||||||
|  | 
 | ||||||
|         auto* ip = before_stmt->Declaration(); |         auto* ip = before_stmt->Declaration(); | ||||||
| 
 | 
 | ||||||
|         auto* else_if = before_stmt->As<sem::IfStatement>(); |         auto* else_if = before_stmt->As<sem::IfStatement>(); | ||||||
| @ -299,58 +353,6 @@ class HoistToDeclBefore::State { | |||||||
|             << "unhandled expression parent statement type: " << parent->TypeInfo().name; |             << "unhandled expression parent statement type: " << parent->TypeInfo().name; | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
| 
 |  | ||||||
|   public: |  | ||||||
|     /// Constructor
 |  | ||||||
|     /// @param ctx_in the clone context
 |  | ||||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} |  | ||||||
| 
 |  | ||||||
|     /// @copydoc HoistToDeclBefore::Add()
 |  | ||||||
|     bool Add(const sem::Expression* before_expr, |  | ||||||
|              const ast::Expression* expr, |  | ||||||
|              bool as_let, |  | ||||||
|              const char* decl_name) { |  | ||||||
|         auto name = b.Symbols().New(decl_name); |  | ||||||
| 
 |  | ||||||
|         if (as_let) { |  | ||||||
|             auto builder = [this, expr, name] { |  | ||||||
|                 return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr))); |  | ||||||
|             }; |  | ||||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { |  | ||||||
|                 return false; |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             auto builder = [this, expr, name] { |  | ||||||
|                 return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); |  | ||||||
|             }; |  | ||||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { |  | ||||||
|                 return false; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // Replace the initializer expression with a reference to the let
 |  | ||||||
|         ctx.Replace(expr, b.Expr(name)); |  | ||||||
|         return true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
 |  | ||||||
|     bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { |  | ||||||
|         if (stmt) { |  | ||||||
|             auto builder = [stmt] { return stmt; }; |  | ||||||
|             return InsertBeforeImpl(before_stmt, std::move(builder)); |  | ||||||
|         } |  | ||||||
|         return InsertBeforeImpl(before_stmt, Decompose{}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
 |  | ||||||
|     bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { |  | ||||||
|         return InsertBeforeImpl(before_stmt, std::move(builder)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /// @copydoc HoistToDeclBefore::Prepare()
 |  | ||||||
|     bool Prepare(const sem::Expression* before_expr) { |  | ||||||
|         return InsertBefore(before_expr->Stmt(), nullptr); |  | ||||||
|     } |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {} | HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {} | ||||||
|  | |||||||
| @ -77,7 +77,7 @@ class HoistToDeclBefore { | |||||||
|     bool Prepare(const sem::Expression* before_expr); |     bool Prepare(const sem::Expression* before_expr); | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     class State; |     struct State; | ||||||
|     std::unique_ptr<State> state_; |     std::unique_ptr<State> state_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -13,6 +13,9 @@ | |||||||
| // limitations under the License.
 | // limitations under the License.
 | ||||||
| 
 | 
 | ||||||
| #include "src/tint/transform/var_for_dynamic_index.h" | #include "src/tint/transform/var_for_dynamic_index.h" | ||||||
|  | 
 | ||||||
|  | #include <utility> | ||||||
|  | 
 | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| #include "src/tint/transform/utils/hoist_to_decl_before.h" | #include "src/tint/transform/utils/hoist_to_decl_before.h" | ||||||
| 
 | 
 | ||||||
| @ -22,7 +25,12 @@ VarForDynamicIndex::VarForDynamicIndex() = default; | |||||||
| 
 | 
 | ||||||
| VarForDynamicIndex::~VarForDynamicIndex() = default; | VarForDynamicIndex::~VarForDynamicIndex() = default; | ||||||
| 
 | 
 | ||||||
| void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src, | ||||||
|  |                                                  const DataMap&, | ||||||
|  |                                                  DataMap&) const { | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     HoistToDeclBefore hoist_to_decl_before(ctx); |     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||||
| 
 | 
 | ||||||
|     // Extracts array and matrix values that are dynamically indexed to a
 |     // Extracts array and matrix values that are dynamically indexed to a
 | ||||||
| @ -30,7 +38,7 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const | |||||||
|     auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) { |     auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) { | ||||||
|         auto* index_expr = access_expr->index; |         auto* index_expr = access_expr->index; | ||||||
|         auto* object_expr = access_expr->object; |         auto* object_expr = access_expr->object; | ||||||
|         auto& sem = ctx.src->Sem(); |         auto& sem = src->Sem(); | ||||||
| 
 | 
 | ||||||
|         if (sem.Get(index_expr)->ConstantValue()) { |         if (sem.Get(index_expr)->ConstantValue()) { | ||||||
|             // Index expression resolves to a compile time value.
 |             // Index expression resolves to a compile time value.
 | ||||||
| @ -49,15 +57,21 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const | |||||||
|         return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index"); |         return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index"); | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { |     bool index_accessor_found = false; | ||||||
|  |     for (auto* node : src->ASTNodes().Objects()) { | ||||||
|         if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) { |         if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) { | ||||||
|             if (!dynamic_index_to_var(access_expr)) { |             if (!dynamic_index_to_var(access_expr)) { | ||||||
|                 return; |                 return Program(std::move(b)); | ||||||
|             } |             } | ||||||
|  |             index_accessor_found = true; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |     if (!index_accessor_found) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -31,14 +31,10 @@ class VarForDynamicIndex : public Transform { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~VarForDynamicIndex() override; |     ~VarForDynamicIndex() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -30,11 +30,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions); | |||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
| 
 | 
 | ||||||
| VectorizeMatrixConversions::VectorizeMatrixConversions() = default; | namespace { | ||||||
| 
 | 
 | ||||||
| VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (auto* sem = program->Sem().Get<sem::Expression>(node)) { |         if (auto* sem = program->Sem().Get<sem::Expression>(node)) { | ||||||
|             if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) { |             if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) { | ||||||
| @ -50,14 +48,29 @@ bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | VectorizeMatrixConversions::VectorizeMatrixConversions() = default; | ||||||
|  | 
 | ||||||
|  | VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src, | ||||||
|  |                                                          const DataMap&, | ||||||
|  |                                                          DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     using HelperFunctionKey = |     using HelperFunctionKey = | ||||||
|         utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>; |         utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>; | ||||||
| 
 | 
 | ||||||
|     std::unordered_map<HelperFunctionKey, Symbol> matrix_convs; |     std::unordered_map<HelperFunctionKey, Symbol> matrix_convs; | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { |     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { | ||||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); |         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|         auto* ty_conv = call->Target()->As<sem::TypeConversion>(); |         auto* ty_conv = call->Target()->As<sem::TypeConversion>(); | ||||||
|         if (!ty_conv) { |         if (!ty_conv) { | ||||||
|             return nullptr; |             return nullptr; | ||||||
| @ -72,16 +85,16 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& | |||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         auto& src = args[0]; |         auto& matrix = args[0]; | ||||||
| 
 | 
 | ||||||
|         auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>(); |         auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>(); | ||||||
|         if (!src_type) { |         if (!src_type) { | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // The source and destination type of a matrix conversion must have a same shape.
 |         // The source and destination type of a matrix conversion must have a same shape.
 | ||||||
|         if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) { |         if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) { | ||||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) |             TINT_ICE(Transform, b.Diagnostics()) | ||||||
|                 << "source and destination matrix has different shape in matrix conversion"; |                 << "source and destination matrix has different shape in matrix conversion"; | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| @ -90,47 +103,45 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& | |||||||
|             utils::Vector<const ast::Expression*, 4> columns; |             utils::Vector<const ast::Expression*, 4> columns; | ||||||
|             for (uint32_t c = 0; c < dst_type->columns(); c++) { |             for (uint32_t c = 0; c < dst_type->columns(); c++) { | ||||||
|                 auto* src_matrix_expr = src_expression_builder(); |                 auto* src_matrix_expr = src_expression_builder(); | ||||||
|                 auto* src_column_expr = |                 auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c))); | ||||||
|                     ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c))); |                 columns.Push( | ||||||
|                 columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), |                     b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr)); | ||||||
|                                                 src_column_expr)); |  | ||||||
|             } |             } | ||||||
|             return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns); |             return b.Construct(CreateASTTypeFor(ctx, dst_type), columns); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Replace the matrix conversion to column vector conversions and a matrix construction.
 |         // Replace the matrix conversion to column vector conversions and a matrix construction.
 | ||||||
|         if (!src->HasSideEffects()) { |         if (!matrix->HasSideEffects()) { | ||||||
|             // Simply use the argument's declaration if it has no side effects.
 |             // Simply use the argument's declaration if it has no side effects.
 | ||||||
|             return build_vectorized_conversion_expression([&]() {  //
 |             return build_vectorized_conversion_expression([&]() {  //
 | ||||||
|                 return ctx.Clone(src->Declaration()); |                 return ctx.Clone(matrix->Declaration()); | ||||||
|             }); |             }); | ||||||
|         } else { |         } else { | ||||||
|             // If has side effects, use a helper function.
 |             // If has side effects, use a helper function.
 | ||||||
|             auto fn = |             auto fn = | ||||||
|                 utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] { |                 utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] { | ||||||
|                     auto name = |                     auto name = b.Symbols().New( | ||||||
|                         ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) + |                         "convert_mat" + std::to_string(src_type->columns()) + "x" + | ||||||
|                                                "x" + std::to_string(src_type->rows()) + "_" + |                         std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) + | ||||||
|                                                ctx.dst->FriendlyName(src_type->type()) + "_" + |                         "_" + b.FriendlyName(dst_type->type())); | ||||||
|                                                ctx.dst->FriendlyName(dst_type->type())); |                     b.Func(name, | ||||||
|                     ctx.dst->Func( |                            utils::Vector{ | ||||||
|                         name, |                                b.Param("value", CreateASTTypeFor(ctx, src_type)), | ||||||
|                         utils::Vector{ |                            }, | ||||||
|                             ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)), |                            CreateASTTypeFor(ctx, dst_type), | ||||||
|                         }, |                            utils::Vector{ | ||||||
|                         CreateASTTypeFor(ctx, dst_type), |                                b.Return(build_vectorized_conversion_expression([&]() {  //
 | ||||||
|                         utils::Vector{ |                                    return b.Expr("value"); | ||||||
|                             ctx.dst->Return(build_vectorized_conversion_expression([&]() {  //
 |                                })), | ||||||
|                                 return ctx.dst->Expr("value"); |                            }); | ||||||
|                             })), |  | ||||||
|                         }); |  | ||||||
|                     return name; |                     return name; | ||||||
|                 }); |                 }); | ||||||
|             return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); |             return b.Call(fn, ctx.Clone(args[0]->Declaration())); | ||||||
|         } |         } | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -28,19 +28,10 @@ class VectorizeMatrixConversions final : public Castable<VectorizeMatrixConversi | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~VectorizeMatrixConversions() override; |     ~VectorizeMatrixConversions() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -27,12 +27,9 @@ | |||||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers); | TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers); | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
|  | namespace { | ||||||
| 
 | 
 | ||||||
| VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default; |  | ||||||
| 
 |  | ||||||
| bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (auto* call = program->Sem().Get<sem::Call>(node)) { |         if (auto* call = program->Sem().Get<sem::Call>(node)) { | ||||||
|             if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) { |             if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) { | ||||||
| @ -46,11 +43,26 @@ bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default; | ||||||
|  | 
 | ||||||
|  | VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src, | ||||||
|  |                                                                 const DataMap&, | ||||||
|  |                                                                 DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     std::unordered_map<const sem::Matrix*, Symbol> scalar_inits; |     std::unordered_map<const sem::Matrix*, Symbol> scalar_inits; | ||||||
| 
 | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { |     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { | ||||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); |         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||||
|         auto* ty_init = call->Target()->As<sem::TypeInitializer>(); |         auto* ty_init = call->Target()->As<sem::TypeInitializer>(); | ||||||
|         if (!ty_init) { |         if (!ty_init) { | ||||||
|             return nullptr; |             return nullptr; | ||||||
| @ -87,10 +99,10 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | |||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 // Construct the column vector.
 |                 // Construct the column vector.
 | ||||||
|                 columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), |                 columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), | ||||||
|                                           std::move(row_values))); |                                    std::move(row_values))); | ||||||
|             } |             } | ||||||
|             return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns); |             return b.Construct(CreateASTTypeFor(ctx, mat_type), columns); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         if (args.Length() == 1) { |         if (args.Length() == 1) { | ||||||
| @ -98,23 +110,22 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | |||||||
|             // This is done to ensure that the single argument value is only evaluated once, and
 |             // This is done to ensure that the single argument value is only evaluated once, and
 | ||||||
|             // with the correct expression evaluation order.
 |             // with the correct expression evaluation order.
 | ||||||
|             auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] { |             auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] { | ||||||
|                 auto name = |                 auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) + | ||||||
|                     ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" + |                                             "x" + std::to_string(mat_type->rows())); | ||||||
|                                            std::to_string(mat_type->rows())); |                 b.Func(name, | ||||||
|                 ctx.dst->Func(name, |                        utils::Vector{ | ||||||
|                               utils::Vector{ |                            // Single scalar parameter
 | ||||||
|                                   // Single scalar parameter
 |                            b.Param("value", CreateASTTypeFor(ctx, mat_type->type())), | ||||||
|                                   ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())), |                        }, | ||||||
|                               }, |                        CreateASTTypeFor(ctx, mat_type), | ||||||
|                               CreateASTTypeFor(ctx, mat_type), |                        utils::Vector{ | ||||||
|                               utils::Vector{ |                            b.Return(build_mat([&](uint32_t, uint32_t) {  //
 | ||||||
|                                   ctx.dst->Return(build_mat([&](uint32_t, uint32_t) {  //
 |                                return b.Expr("value"); | ||||||
|                                       return ctx.dst->Expr("value"); |                            })), | ||||||
|                                   })), |                        }); | ||||||
|                               }); |  | ||||||
|                 return name; |                 return name; | ||||||
|             }); |             }); | ||||||
|             return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); |             return b.Call(fn, ctx.Clone(args[0]->Declaration())); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if (args.Length() == mat_type->columns() * mat_type->rows()) { |         if (args.Length() == mat_type->columns() * mat_type->rows()) { | ||||||
| @ -123,12 +134,13 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | |||||||
|             }); |             }); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         TINT_ICE(Transform, ctx.dst->Diagnostics()) |         TINT_ICE(Transform, b.Diagnostics()) | ||||||
|             << "matrix initializer has unexpected number of arguments"; |             << "matrix initializer has unexpected number of arguments"; | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -29,19 +29,10 @@ class VectorizeScalarMatrixInitializers final | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~VectorizeScalarMatrixInitializers() override; |     ~VectorizeScalarMatrixInitializers() override; | ||||||
| 
 | 
 | ||||||
|     /// @param program the program to inspect
 |     /// @copydoc Transform::Apply
 | ||||||
|     /// @param data optional extra transform-specific input data
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// @returns true if this transform should be run for the given program
 |                       const DataMap& inputs, | ||||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; |                       DataMap& outputs) const override; | ||||||
| 
 |  | ||||||
|   protected: |  | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |  | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |  | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |  | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
| @ -201,13 +201,46 @@ DataType DataTypeOf(VertexFormat format) { | |||||||
|     return {BaseType::kInvalid, 0}; |     return {BaseType::kInvalid, 0}; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct State { | }  // namespace
 | ||||||
|     State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {} |  | ||||||
|     State(const State&) = default; |  | ||||||
|     ~State() = default; |  | ||||||
| 
 | 
 | ||||||
|     /// LocationReplacement describes an ast::Variable replacement for a
 | /// PIMPL state for the transform
 | ||||||
|     /// location input.
 | struct VertexPulling::State { | ||||||
|  |     /// Constructor
 | ||||||
|  |     /// @param program the source program
 | ||||||
|  |     /// @param c the VertexPulling config
 | ||||||
|  |     State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {} | ||||||
|  | 
 | ||||||
|  |     /// Runs the transform
 | ||||||
|  |     /// @returns the new program or SkipTransform if the transform is not required
 | ||||||
|  |     ApplyResult Run() { | ||||||
|  |         // Find entry point
 | ||||||
|  |         const ast::Function* func = nullptr; | ||||||
|  |         for (auto* fn : src->AST().Functions()) { | ||||||
|  |             if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||||
|  |                 if (func != nullptr) { | ||||||
|  |                     b.Diagnostics().add_error( | ||||||
|  |                         diag::System::Transform, | ||||||
|  |                         "VertexPulling found more than one vertex entry point"); | ||||||
|  |                     return Program(std::move(b)); | ||||||
|  |                 } | ||||||
|  |                 func = fn; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         if (func == nullptr) { | ||||||
|  |             b.Diagnostics().add_error(diag::System::Transform, | ||||||
|  |                                       "Vertex stage entry point not found"); | ||||||
|  |             return Program(std::move(b)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         AddVertexStorageBuffers(); | ||||||
|  |         Process(func); | ||||||
|  | 
 | ||||||
|  |         ctx.Clone(); | ||||||
|  |         return Program(std::move(b)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|  |     /// LocationReplacement describes an ast::Variable replacement for a location input.
 | ||||||
|     struct LocationReplacement { |     struct LocationReplacement { | ||||||
|         /// The variable to replace in the source Program
 |         /// The variable to replace in the source Program
 | ||||||
|         ast::Variable* from; |         ast::Variable* from; | ||||||
| @ -215,13 +248,22 @@ struct State { | |||||||
|         ast::Variable* to; |         ast::Variable* to; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     /// LocationInfo describes an input location
 | ||||||
|     struct LocationInfo { |     struct LocationInfo { | ||||||
|  |         /// A builder that builds the expression that resolves to the (transformed) input location
 | ||||||
|         std::function<const ast::Expression*()> expr; |         std::function<const ast::Expression*()> expr; | ||||||
|  |         /// The store type of the location variable
 | ||||||
|         const sem::Type* type; |         const sem::Type* type; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     CloneContext& ctx; |     /// The source program
 | ||||||
|  |     const Program* const src; | ||||||
|  |     /// The transform config
 | ||||||
|     VertexPulling::Config const cfg; |     VertexPulling::Config const cfg; | ||||||
|  |     /// The target program builder
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     /// The clone context
 | ||||||
|  |     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||||
|     std::unordered_map<uint32_t, LocationInfo> location_info; |     std::unordered_map<uint32_t, LocationInfo> location_info; | ||||||
|     std::function<const ast::Expression*()> vertex_index_expr = nullptr; |     std::function<const ast::Expression*()> vertex_index_expr = nullptr; | ||||||
|     std::function<const ast::Expression*()> instance_index_expr = nullptr; |     std::function<const ast::Expression*()> instance_index_expr = nullptr; | ||||||
| @ -235,7 +277,7 @@ struct State { | |||||||
|     Symbol GetVertexBufferName(uint32_t index) { |     Symbol GetVertexBufferName(uint32_t index) { | ||||||
|         return utils::GetOrCreate(vertex_buffer_names, index, [&] { |         return utils::GetOrCreate(vertex_buffer_names, index, [&] { | ||||||
|             static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; |             static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; | ||||||
|             return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); |             return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -243,7 +285,7 @@ struct State { | |||||||
|     Symbol GetStructBufferName() { |     Symbol GetStructBufferName() { | ||||||
|         if (!struct_buffer_name.IsValid()) { |         if (!struct_buffer_name.IsValid()) { | ||||||
|             static const char kStructBufferName[] = "tint_vertex_data"; |             static const char kStructBufferName[] = "tint_vertex_data"; | ||||||
|             struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName); |             struct_buffer_name = b.Symbols().New(kStructBufferName); | ||||||
|         } |         } | ||||||
|         return struct_buffer_name; |         return struct_buffer_name; | ||||||
|     } |     } | ||||||
| @ -252,21 +294,19 @@ struct State { | |||||||
|     void AddVertexStorageBuffers() { |     void AddVertexStorageBuffers() { | ||||||
|         // Creating the struct type
 |         // Creating the struct type
 | ||||||
|         static const char kStructName[] = "TintVertexData"; |         static const char kStructName[] = "TintVertexData"; | ||||||
|         auto* struct_type = |         auto* struct_type = b.Structure(b.Symbols().New(kStructName), | ||||||
|             ctx.dst->Structure(ctx.dst->Symbols().New(kStructName), |                                         utils::Vector{ | ||||||
|                                utils::Vector{ |                                             b.Member(GetStructBufferName(), b.ty.array<u32>()), | ||||||
|                                    ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<u32>()), |                                         }); | ||||||
|                                }); |  | ||||||
|         for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { |         for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { | ||||||
|             // The decorated variable with struct type
 |             // The decorated variable with struct type
 | ||||||
|             ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type), |             b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage, | ||||||
|                                ast::AddressSpace::kStorage, ast::Access::kRead, |                         ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group))); | ||||||
|                                ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group))); |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Creates and returns the assignment to the variables from the buffers
 |     /// Creates and returns the assignment to the variables from the buffers
 | ||||||
|     ast::BlockStatement* CreateVertexPullingPreamble() { |     const ast::BlockStatement* CreateVertexPullingPreamble() { | ||||||
|         // Assign by looking at the vertex descriptor to find attributes with
 |         // Assign by looking at the vertex descriptor to find attributes with
 | ||||||
|         // matching location.
 |         // matching location.
 | ||||||
| 
 | 
 | ||||||
| @ -276,7 +316,7 @@ struct State { | |||||||
|             const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; |             const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; | ||||||
| 
 | 
 | ||||||
|             if ((buffer_layout.array_stride & 3) != 0) { |             if ((buffer_layout.array_stride & 3) != 0) { | ||||||
|                 ctx.dst->Diagnostics().add_error( |                 b.Diagnostics().add_error( | ||||||
|                     diag::System::Transform, |                     diag::System::Transform, | ||||||
|                     "WebGPU requires that vertex stride must be a multiple of 4 bytes, " |                     "WebGPU requires that vertex stride must be a multiple of 4 bytes, " | ||||||
|                     "but VertexPulling array stride for buffer " + |                     "but VertexPulling array stride for buffer " + | ||||||
| @ -292,15 +332,15 @@ struct State { | |||||||
|             // buffer_array_base is the base array offset for all the vertex
 |             // buffer_array_base is the base array offset for all the vertex
 | ||||||
|             // attributes. These are units of uint (4 bytes).
 |             // attributes. These are units of uint (4 bytes).
 | ||||||
|             auto buffer_array_base = |             auto buffer_array_base = | ||||||
|                 ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); |                 b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); | ||||||
| 
 | 
 | ||||||
|             auto* attribute_offset = index_expr; |             auto* attribute_offset = index_expr; | ||||||
|             if (buffer_layout.array_stride != 4) { |             if (buffer_layout.array_stride != 4) { | ||||||
|                 attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u)); |                 attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u)); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // let pulling_offset_n = <attribute_offset>
 |             // let pulling_offset_n = <attribute_offset>
 | ||||||
|             stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset))); |             stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset))); | ||||||
| 
 | 
 | ||||||
|             for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { |             for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { | ||||||
|                 auto it = location_info.find(attribute_desc.shader_location); |                 auto it = location_info.find(attribute_desc.shader_location); | ||||||
| @ -320,8 +360,8 @@ struct State { | |||||||
|                     err << "VertexAttributeDescriptor for location " |                     err << "VertexAttributeDescriptor for location " | ||||||
|                         << std::to_string(attribute_desc.shader_location) << " has format " |                         << std::to_string(attribute_desc.shader_location) << " has format " | ||||||
|                         << attribute_desc.format << " but shader expects " |                         << attribute_desc.format << " but shader expects " | ||||||
|                         << var.type->FriendlyName(ctx.src->Symbols()); |                         << var.type->FriendlyName(src->Symbols()); | ||||||
|                     ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str()); |                     b.Diagnostics().add_error(diag::System::Transform, err.str()); | ||||||
|                     return nullptr; |                     return nullptr; | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
| @ -337,16 +377,16 @@ struct State { | |||||||
|                     // WGSL variable vector width is smaller than the loaded vector width
 |                     // WGSL variable vector width is smaller than the loaded vector width
 | ||||||
|                     switch (var_dt.width) { |                     switch (var_dt.width) { | ||||||
|                         case 1: |                         case 1: | ||||||
|                             value = ctx.dst->MemberAccessor(fetch, "x"); |                             value = b.MemberAccessor(fetch, "x"); | ||||||
|                             break; |                             break; | ||||||
|                         case 2: |                         case 2: | ||||||
|                             value = ctx.dst->MemberAccessor(fetch, "xy"); |                             value = b.MemberAccessor(fetch, "xy"); | ||||||
|                             break; |                             break; | ||||||
|                         case 3: |                         case 3: | ||||||
|                             value = ctx.dst->MemberAccessor(fetch, "xyz"); |                             value = b.MemberAccessor(fetch, "xyz"); | ||||||
|                             break; |                             break; | ||||||
|                         default: |                         default: | ||||||
|                             TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width; |                             TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width; | ||||||
|                             return nullptr; |                             return nullptr; | ||||||
|                     } |                     } | ||||||
|                 } else if (var_dt.width > fmt_dt.width) { |                 } else if (var_dt.width > fmt_dt.width) { | ||||||
| @ -355,32 +395,32 @@ struct State { | |||||||
|                     utils::Vector<const ast::Expression*, 8> values{fetch}; |                     utils::Vector<const ast::Expression*, 8> values{fetch}; | ||||||
|                     switch (var_dt.base_type) { |                     switch (var_dt.base_type) { | ||||||
|                         case BaseType::kI32: |                         case BaseType::kI32: | ||||||
|                             ty = ctx.dst->ty.i32(); |                             ty = b.ty.i32(); | ||||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { |                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i)); |                                 values.Push(b.Expr((i == 3) ? 1_i : 0_i)); | ||||||
|                             } |                             } | ||||||
|                             break; |                             break; | ||||||
|                         case BaseType::kU32: |                         case BaseType::kU32: | ||||||
|                             ty = ctx.dst->ty.u32(); |                             ty = b.ty.u32(); | ||||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { |                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u)); |                                 values.Push(b.Expr((i == 3) ? 1_u : 0_u)); | ||||||
|                             } |                             } | ||||||
|                             break; |                             break; | ||||||
|                         case BaseType::kF32: |                         case BaseType::kF32: | ||||||
|                             ty = ctx.dst->ty.f32(); |                             ty = b.ty.f32(); | ||||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { |                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f)); |                                 values.Push(b.Expr((i == 3) ? 1_f : 0_f)); | ||||||
|                             } |                             } | ||||||
|                             break; |                             break; | ||||||
|                         default: |                         default: | ||||||
|                             TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type; |                             TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type; | ||||||
|                             return nullptr; |                             return nullptr; | ||||||
|                     } |                     } | ||||||
|                     value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values); |                     value = b.Construct(b.ty.vec(ty, var_dt.width), values); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 // Assign the value to the WGSL variable
 |                 // Assign the value to the WGSL variable
 | ||||||
|                 stmts.Push(ctx.dst->Assign(var.expr(), value)); |                 stmts.Push(b.Assign(var.expr(), value)); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -388,7 +428,7 @@ struct State { | |||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return ctx.dst->create<ast::BlockStatement>(std::move(stmts)); |         return b.Block(std::move(stmts)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Generates an expression reading from a buffer a specific format.
 |     /// Generates an expression reading from a buffer a specific format.
 | ||||||
| @ -407,7 +447,7 @@ struct State { | |||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Returns a i32 loaded from buffer_base + offset.
 |         // Returns a i32 loaded from buffer_base + offset.
 | ||||||
|         auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); }; |         auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); }; | ||||||
| 
 | 
 | ||||||
|         // Returns a u32 loaded from buffer_base + offset + 4.
 |         // Returns a u32 loaded from buffer_base + offset + 4.
 | ||||||
|         auto load_next_u32 = [&] { |         auto load_next_u32 = [&] { | ||||||
| @ -415,7 +455,7 @@ struct State { | |||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Returns a i32 loaded from buffer_base + offset + 4.
 |         // Returns a i32 loaded from buffer_base + offset + 4.
 | ||||||
|         auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); }; |         auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); }; | ||||||
| 
 | 
 | ||||||
|         // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
 |         // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
 | ||||||
|         // The low 16 bits are 0.
 |         // The low 16 bits are 0.
 | ||||||
| @ -427,17 +467,17 @@ struct State { | |||||||
|                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); |                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); | ||||||
|             switch (offset & 3) { |             switch (offset & 3) { | ||||||
|                 case 0: |                 case 0: | ||||||
|                     return ctx.dst->Shl(low_u32, 16_u); |                     return b.Shl(low_u32, 16_u); | ||||||
|                 case 1: |                 case 1: | ||||||
|                     return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u); |                     return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u); | ||||||
|                 case 2: |                 case 2: | ||||||
|                     return ctx.dst->And(low_u32, 0xffff0000_u); |                     return b.And(low_u32, 0xffff0000_u); | ||||||
|                 default: {  // 3:
 |                 default: {  // 3:
 | ||||||
|                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, |                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, | ||||||
|                                                    VertexFormat::kUint32); |                                                    VertexFormat::kUint32); | ||||||
|                     auto* shr = ctx.dst->Shr(low_u32, 8_u); |                     auto* shr = b.Shr(low_u32, 8_u); | ||||||
|                     auto* shl = ctx.dst->Shl(high_u32, 24_u); |                     auto* shl = b.Shl(high_u32, 24_u); | ||||||
|                     return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u); |                     return b.And(b.Or(shl, shr), 0xffff0000_u); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| @ -450,24 +490,24 @@ struct State { | |||||||
|                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); |                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); | ||||||
|             switch (offset & 3) { |             switch (offset & 3) { | ||||||
|                 case 0: |                 case 0: | ||||||
|                     return ctx.dst->And(low_u32, 0xffff_u); |                     return b.And(low_u32, 0xffff_u); | ||||||
|                 case 1: |                 case 1: | ||||||
|                     return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u); |                     return b.And(b.Shr(low_u32, 8_u), 0xffff_u); | ||||||
|                 case 2: |                 case 2: | ||||||
|                     return ctx.dst->Shr(low_u32, 16_u); |                     return b.Shr(low_u32, 16_u); | ||||||
|                 default: {  // 3:
 |                 default: {  // 3:
 | ||||||
|                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, |                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, | ||||||
|                                                    VertexFormat::kUint32); |                                                    VertexFormat::kUint32); | ||||||
|                     auto* shr = ctx.dst->Shr(low_u32, 24_u); |                     auto* shr = b.Shr(low_u32, 24_u); | ||||||
|                     auto* shl = ctx.dst->Shl(high_u32, 8_u); |                     auto* shl = b.Shl(high_u32, 8_u); | ||||||
|                     return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u); |                     return b.And(b.Or(shl, shr), 0xffff_u); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
 |         // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
 | ||||||
|         // The low 16 bits are 0.
 |         // The low 16 bits are 0.
 | ||||||
|         auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); }; |         auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); }; | ||||||
| 
 | 
 | ||||||
|         // Assumptions are made that alignment must be at least as large as the size
 |         // Assumptions are made that alignment must be at least as large as the size
 | ||||||
|         // of a single component.
 |         // of a single component.
 | ||||||
| @ -480,128 +520,121 @@ struct State { | |||||||
| 
 | 
 | ||||||
|                 // Vectors of basic primitives
 |                 // Vectors of basic primitives
 | ||||||
|             case VertexFormat::kUint32x2: |             case VertexFormat::kUint32x2: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2); | ||||||
|                                VertexFormat::kUint32, 2); |  | ||||||
|             case VertexFormat::kUint32x3: |             case VertexFormat::kUint32x3: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3); | ||||||
|                                VertexFormat::kUint32, 3); |  | ||||||
|             case VertexFormat::kUint32x4: |             case VertexFormat::kUint32x4: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4); | ||||||
|                                VertexFormat::kUint32, 4); |  | ||||||
|             case VertexFormat::kSint32x2: |             case VertexFormat::kSint32x2: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2); | ||||||
|                                VertexFormat::kSint32, 2); |  | ||||||
|             case VertexFormat::kSint32x3: |             case VertexFormat::kSint32x3: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3); | ||||||
|                                VertexFormat::kSint32, 3); |  | ||||||
|             case VertexFormat::kSint32x4: |             case VertexFormat::kSint32x4: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4); | ||||||
|                                VertexFormat::kSint32, 4); |  | ||||||
|             case VertexFormat::kFloat32x2: |             case VertexFormat::kFloat32x2: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||||
|                                VertexFormat::kFloat32, 2); |                                2); | ||||||
|             case VertexFormat::kFloat32x3: |             case VertexFormat::kFloat32x3: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||||
|                                VertexFormat::kFloat32, 3); |                                3); | ||||||
|             case VertexFormat::kFloat32x4: |             case VertexFormat::kFloat32x4: | ||||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), |                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||||
|                                VertexFormat::kFloat32, 4); |                                4); | ||||||
| 
 | 
 | ||||||
|             case VertexFormat::kUint8x2: { |             case VertexFormat::kUint8x2: { | ||||||
|                 // yyxx0000, yyxx0000
 |                 // yyxx0000, yyxx0000
 | ||||||
|                 auto* u16s = ctx.dst->vec2<u32>(load_u16_h()); |                 auto* u16s = b.vec2<u32>(load_u16_h()); | ||||||
|                 // xx000000, yyxx0000
 |                 // xx000000, yyxx0000
 | ||||||
|                 auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8_u, 0_u)); |                 auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u)); | ||||||
|                 // 000000xx, 000000yy
 |                 // 000000xx, 000000yy
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u)); |                 return b.Shr(shl, b.vec2<u32>(24_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kUint8x4: { |             case VertexFormat::kUint8x4: { | ||||||
|                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 |                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 | ||||||
|                 auto* u32s = ctx.dst->vec4<u32>(load_u32()); |                 auto* u32s = b.vec4<u32>(load_u32()); | ||||||
|                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 |                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 | ||||||
|                 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u)); |                 auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||||
|                 // 000000xx, 000000yy, 000000zz, 000000ww
 |                 // 000000xx, 000000yy, 000000zz, 000000ww
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u)); |                 return b.Shr(shl, b.vec4<u32>(24_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kUint16x2: { |             case VertexFormat::kUint16x2: { | ||||||
|                 // yyyyxxxx, yyyyxxxx
 |                 // yyyyxxxx, yyyyxxxx
 | ||||||
|                 auto* u32s = ctx.dst->vec2<u32>(load_u32()); |                 auto* u32s = b.vec2<u32>(load_u32()); | ||||||
|                 // xxxx0000, yyyyxxxx
 |                 // xxxx0000, yyyyxxxx
 | ||||||
|                 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16_u, 0_u)); |                 auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u)); | ||||||
|                 // 0000xxxx, 0000yyyy
 |                 // 0000xxxx, 0000yyyy
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u)); |                 return b.Shr(shl, b.vec2<u32>(16_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kUint16x4: { |             case VertexFormat::kUint16x4: { | ||||||
|                 // yyyyxxxx, wwwwzzzz
 |                 // yyyyxxxx, wwwwzzzz
 | ||||||
|                 auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32()); |                 auto* u32s = b.vec2<u32>(load_u32(), load_next_u32()); | ||||||
|                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 |                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 | ||||||
|                 auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy"); |                 auto* xxyy = b.MemberAccessor(u32s, "xxyy"); | ||||||
|                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 |                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 | ||||||
|                 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u)); |                 auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||||
|                 // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
 |                 // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u)); |                 return b.Shr(shl, b.vec4<u32>(16_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kSint8x2: { |             case VertexFormat::kSint8x2: { | ||||||
|                 // yyxx0000, yyxx0000
 |                 // yyxx0000, yyxx0000
 | ||||||
|                 auto* i16s = ctx.dst->vec2<i32>(load_i16_h()); |                 auto* i16s = b.vec2<i32>(load_i16_h()); | ||||||
|                 // xx000000, yyxx0000
 |                 // xx000000, yyxx0000
 | ||||||
|                 auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8_u, 0_u)); |                 auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u)); | ||||||
|                 // ssssssxx, ssssssyy
 |                 // ssssssxx, ssssssyy
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u)); |                 return b.Shr(shl, b.vec2<u32>(24_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kSint8x4: { |             case VertexFormat::kSint8x4: { | ||||||
|                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 |                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 | ||||||
|                 auto* i32s = ctx.dst->vec4<i32>(load_i32()); |                 auto* i32s = b.vec4<i32>(load_i32()); | ||||||
|                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 |                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 | ||||||
|                 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u)); |                 auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||||
|                 // ssssssxx, ssssssyy, sssssszz, ssssssww
 |                 // ssssssxx, ssssssyy, sssssszz, ssssssww
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u)); |                 return b.Shr(shl, b.vec4<u32>(24_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kSint16x2: { |             case VertexFormat::kSint16x2: { | ||||||
|                 // yyyyxxxx, yyyyxxxx
 |                 // yyyyxxxx, yyyyxxxx
 | ||||||
|                 auto* i32s = ctx.dst->vec2<i32>(load_i32()); |                 auto* i32s = b.vec2<i32>(load_i32()); | ||||||
|                 // xxxx0000, yyyyxxxx
 |                 // xxxx0000, yyyyxxxx
 | ||||||
|                 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16_u, 0_u)); |                 auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u)); | ||||||
|                 // ssssxxxx, ssssyyyy
 |                 // ssssxxxx, ssssyyyy
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u)); |                 return b.Shr(shl, b.vec2<u32>(16_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kSint16x4: { |             case VertexFormat::kSint16x4: { | ||||||
|                 // yyyyxxxx, wwwwzzzz
 |                 // yyyyxxxx, wwwwzzzz
 | ||||||
|                 auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32()); |                 auto* i32s = b.vec2<i32>(load_i32(), load_next_i32()); | ||||||
|                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 |                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 | ||||||
|                 auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy"); |                 auto* xxyy = b.MemberAccessor(i32s, "xxyy"); | ||||||
|                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 |                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 | ||||||
|                 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u)); |                 auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||||
|                 // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
 |                 // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
 | ||||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u)); |                 return b.Shr(shl, b.vec4<u32>(16_u)); | ||||||
|             } |             } | ||||||
|             case VertexFormat::kUnorm8x2: |             case VertexFormat::kUnorm8x2: | ||||||
|                 return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy"); |                 return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy"); | ||||||
|             case VertexFormat::kSnorm8x2: |             case VertexFormat::kSnorm8x2: | ||||||
|                 return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy"); |                 return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy"); | ||||||
|             case VertexFormat::kUnorm8x4: |             case VertexFormat::kUnorm8x4: | ||||||
|                 return ctx.dst->Call("unpack4x8unorm", load_u32()); |                 return b.Call("unpack4x8unorm", load_u32()); | ||||||
|             case VertexFormat::kSnorm8x4: |             case VertexFormat::kSnorm8x4: | ||||||
|                 return ctx.dst->Call("unpack4x8snorm", load_u32()); |                 return b.Call("unpack4x8snorm", load_u32()); | ||||||
|             case VertexFormat::kUnorm16x2: |             case VertexFormat::kUnorm16x2: | ||||||
|                 return ctx.dst->Call("unpack2x16unorm", load_u32()); |                 return b.Call("unpack2x16unorm", load_u32()); | ||||||
|             case VertexFormat::kSnorm16x2: |             case VertexFormat::kSnorm16x2: | ||||||
|                 return ctx.dst->Call("unpack2x16snorm", load_u32()); |                 return b.Call("unpack2x16snorm", load_u32()); | ||||||
|             case VertexFormat::kFloat16x2: |             case VertexFormat::kFloat16x2: | ||||||
|                 return ctx.dst->Call("unpack2x16float", load_u32()); |                 return b.Call("unpack2x16float", load_u32()); | ||||||
|             case VertexFormat::kUnorm16x4: |             case VertexFormat::kUnorm16x4: | ||||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()), |                 return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()), | ||||||
|                                           ctx.dst->Call("unpack2x16unorm", load_next_u32())); |                                    b.Call("unpack2x16unorm", load_next_u32())); | ||||||
|             case VertexFormat::kSnorm16x4: |             case VertexFormat::kSnorm16x4: | ||||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()), |                 return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()), | ||||||
|                                           ctx.dst->Call("unpack2x16snorm", load_next_u32())); |                                    b.Call("unpack2x16snorm", load_next_u32())); | ||||||
|             case VertexFormat::kFloat16x4: |             case VertexFormat::kFloat16x4: | ||||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()), |                 return b.vec4<f32>(b.Call("unpack2x16float", load_u32()), | ||||||
|                                           ctx.dst->Call("unpack2x16float", load_next_u32())); |                                    b.Call("unpack2x16float", load_next_u32())); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) |         TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format); | ||||||
|             << "format " << static_cast<int>(format); |  | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -623,12 +656,12 @@ struct State { | |||||||
| 
 | 
 | ||||||
|             const ast ::Expression* index = nullptr; |             const ast ::Expression* index = nullptr; | ||||||
|             if (offset > 0) { |             if (offset > 0) { | ||||||
|                 index = ctx.dst->Add(array_base, u32(offset / 4)); |                 index = b.Add(array_base, u32(offset / 4)); | ||||||
|             } else { |             } else { | ||||||
|                 index = ctx.dst->Expr(array_base); |                 index = b.Expr(array_base); | ||||||
|             } |             } | ||||||
|             u = ctx.dst->IndexAccessor( |             u = b.IndexAccessor( | ||||||
|                 ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); |                 b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); | ||||||
| 
 | 
 | ||||||
|         } else { |         } else { | ||||||
|             // Unaligned load
 |             // Unaligned load
 | ||||||
| @ -639,22 +672,22 @@ struct State { | |||||||
| 
 | 
 | ||||||
|             uint32_t shift = 8u * (offset & 3u); |             uint32_t shift = 8u * (offset & 3u); | ||||||
| 
 | 
 | ||||||
|             auto* low_shr = ctx.dst->Shr(low, u32(shift)); |             auto* low_shr = b.Shr(low, u32(shift)); | ||||||
|             auto* high_shl = ctx.dst->Shl(high, u32(32u - shift)); |             auto* high_shl = b.Shl(high, u32(32u - shift)); | ||||||
|             u = ctx.dst->Or(low_shr, high_shl); |             u = b.Or(low_shr, high_shl); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         switch (format) { |         switch (format) { | ||||||
|             case VertexFormat::kUint32: |             case VertexFormat::kUint32: | ||||||
|                 return u; |                 return u; | ||||||
|             case VertexFormat::kSint32: |             case VertexFormat::kSint32: | ||||||
|                 return ctx.dst->Bitcast(ctx.dst->ty.i32(), u); |                 return b.Bitcast(b.ty.i32(), u); | ||||||
|             case VertexFormat::kFloat32: |             case VertexFormat::kFloat32: | ||||||
|                 return ctx.dst->Bitcast(ctx.dst->ty.f32(), u); |                 return b.Bitcast(b.ty.f32(), u); | ||||||
|             default: |             default: | ||||||
|                 break; |                 break; | ||||||
|         } |         } | ||||||
|         TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) |         TINT_UNREACHABLE(Transform, b.Diagnostics()) | ||||||
|             << "invalid format for LoadPrimitive" << static_cast<int>(format); |             << "invalid format for LoadPrimitive" << static_cast<int>(format); | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     } |     } | ||||||
| @ -682,8 +715,7 @@ struct State { | |||||||
|             expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format)); |             expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count), |         return b.Construct(b.create<ast::Vector>(base_type, count), std::move(expr_list)); | ||||||
|                                   std::move(expr_list)); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Process a non-struct entry point parameter.
 |     /// Process a non-struct entry point parameter.
 | ||||||
| @ -696,34 +728,30 @@ struct State { | |||||||
|             // Create a function-scope variable to replace the parameter.
 |             // Create a function-scope variable to replace the parameter.
 | ||||||
|             auto func_var_sym = ctx.Clone(param->symbol); |             auto func_var_sym = ctx.Clone(param->symbol); | ||||||
|             auto* func_var_type = ctx.Clone(param->type); |             auto* func_var_type = ctx.Clone(param->type); | ||||||
|             auto* func_var = ctx.dst->Var(func_var_sym, func_var_type); |             auto* func_var = b.Var(func_var_sym, func_var_type); | ||||||
|             ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); |             ctx.InsertFront(func->body->statements, b.Decl(func_var)); | ||||||
|             // Capture mapping from location to the new variable.
 |             // Capture mapping from location to the new variable.
 | ||||||
|             LocationInfo info; |             LocationInfo info; | ||||||
|             info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; |             info.expr = [this, func_var]() { return b.Expr(func_var); }; | ||||||
| 
 | 
 | ||||||
|             auto* sem = ctx.src->Sem().Get<sem::Parameter>(param); |             auto* sem = src->Sem().Get<sem::Parameter>(param); | ||||||
|             info.type = sem->Type(); |             info.type = sem->Type(); | ||||||
| 
 | 
 | ||||||
|             if (!sem->Location().has_value()) { |             if (!sem->Location().has_value()) { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value"; |                 TINT_ICE(Transform, b.Diagnostics()) << "Location missing value"; | ||||||
|                 return; |                 return; | ||||||
|             } |             } | ||||||
|             location_info[sem->Location().value()] = info; |             location_info[sem->Location().value()] = info; | ||||||
|         } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { |         } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { | ||||||
|             // Check for existing vertex_index and instance_index builtins.
 |             // Check for existing vertex_index and instance_index builtins.
 | ||||||
|             if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { |             if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { | ||||||
|                 vertex_index_expr = [this, param]() { |                 vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; | ||||||
|                     return ctx.dst->Expr(ctx.Clone(param->symbol)); |  | ||||||
|                 }; |  | ||||||
|             } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { |             } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { | ||||||
|                 instance_index_expr = [this, param]() { |                 instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; | ||||||
|                     return ctx.dst->Expr(ctx.Clone(param->symbol)); |  | ||||||
|                 }; |  | ||||||
|             } |             } | ||||||
|             new_function_parameters.Push(ctx.Clone(param)); |             new_function_parameters.Push(ctx.Clone(param)); | ||||||
|         } else { |         } else { | ||||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; |             TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -746,7 +774,7 @@ struct State { | |||||||
|         for (auto* member : struct_ty->members) { |         for (auto* member : struct_ty->members) { | ||||||
|             auto member_sym = ctx.Clone(member->symbol); |             auto member_sym = ctx.Clone(member->symbol); | ||||||
|             std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() { |             std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() { | ||||||
|                 return ctx.dst->MemberAccessor(param_sym, member_sym); |                 return b.MemberAccessor(param_sym, member_sym); | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|             if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) { |             if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) { | ||||||
| @ -754,7 +782,7 @@ struct State { | |||||||
|                 LocationInfo info; |                 LocationInfo info; | ||||||
|                 info.expr = member_expr; |                 info.expr = member_expr; | ||||||
| 
 | 
 | ||||||
|                 auto* sem = ctx.src->Sem().Get(member); |                 auto* sem = src->Sem().Get(member); | ||||||
|                 info.type = sem->Type(); |                 info.type = sem->Type(); | ||||||
| 
 | 
 | ||||||
|                 TINT_ASSERT(Transform, sem->Location().has_value()); |                 TINT_ASSERT(Transform, sem->Location().has_value()); | ||||||
| @ -770,7 +798,7 @@ struct State { | |||||||
|                 } |                 } | ||||||
|                 members_to_clone.Push(member); |                 members_to_clone.Push(member); | ||||||
|             } else { |             } else { | ||||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; |                 TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -781,8 +809,8 @@ struct State { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Create a function-scope variable to replace the parameter.
 |         // Create a function-scope variable to replace the parameter.
 | ||||||
|         auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type)); |         auto* func_var = b.Var(param_sym, ctx.Clone(param->type)); | ||||||
|         ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); |         ctx.InsertFront(func->body->statements, b.Decl(func_var)); | ||||||
| 
 | 
 | ||||||
|         if (!members_to_clone.IsEmpty()) { |         if (!members_to_clone.IsEmpty()) { | ||||||
|             // Create a new struct without the location attributes.
 |             // Create a new struct without the location attributes.
 | ||||||
| @ -791,20 +819,20 @@ struct State { | |||||||
|                 auto member_sym = ctx.Clone(member->symbol); |                 auto member_sym = ctx.Clone(member->symbol); | ||||||
|                 auto* member_type = ctx.Clone(member->type); |                 auto* member_type = ctx.Clone(member->type); | ||||||
|                 auto member_attrs = ctx.Clone(member->attributes); |                 auto member_attrs = ctx.Clone(member->attributes); | ||||||
|                 new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs))); |                 new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs))); | ||||||
|             } |             } | ||||||
|             auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members); |             auto* new_struct = b.Structure(b.Sym(), new_members); | ||||||
| 
 | 
 | ||||||
|             // Create a new function parameter with this struct.
 |             // Create a new function parameter with this struct.
 | ||||||
|             auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct)); |             auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct)); | ||||||
|             new_function_parameters.Push(new_param); |             new_function_parameters.Push(new_param); | ||||||
| 
 | 
 | ||||||
|             // Copy values from the new parameter to the function-scope variable.
 |             // Copy values from the new parameter to the function-scope variable.
 | ||||||
|             for (auto* member : members_to_clone) { |             for (auto* member : members_to_clone) { | ||||||
|                 auto member_name = ctx.Clone(member->symbol); |                 auto member_name = ctx.Clone(member->symbol); | ||||||
|                 ctx.InsertFront(func->body->statements, |                 ctx.InsertFront(func->body->statements, | ||||||
|                                 ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name), |                                 b.Assign(b.MemberAccessor(func_var, member_name), | ||||||
|                                                 ctx.dst->MemberAccessor(new_param, member_name))); |                                          b.MemberAccessor(new_param, member_name))); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -818,7 +846,7 @@ struct State { | |||||||
| 
 | 
 | ||||||
|         // Process entry point parameters.
 |         // Process entry point parameters.
 | ||||||
|         for (auto* param : func->params) { |         for (auto* param : func->params) { | ||||||
|             auto* sem = ctx.src->Sem().Get(param); |             auto* sem = src->Sem().Get(param); | ||||||
|             if (auto* str = sem->Type()->As<sem::Struct>()) { |             if (auto* str = sem->Type()->As<sem::Struct>()) { | ||||||
|                 ProcessStructParameter(func, param, str->Declaration()); |                 ProcessStructParameter(func, param, str->Declaration()); | ||||||
|             } else { |             } else { | ||||||
| @ -830,11 +858,11 @@ struct State { | |||||||
|         if (!vertex_index_expr) { |         if (!vertex_index_expr) { | ||||||
|             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { |             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | ||||||
|                 if (layout.step_mode == VertexStepMode::kVertex) { |                 if (layout.step_mode == VertexStepMode::kVertex) { | ||||||
|                     auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); |                     auto name = b.Symbols().New("tint_pulling_vertex_index"); | ||||||
|                     new_function_parameters.Push(ctx.dst->Param( |                     new_function_parameters.Push( | ||||||
|                         name, ctx.dst->ty.u32(), |                         b.Param(name, b.ty.u32(), | ||||||
|                         utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)})); |                                 utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)})); | ||||||
|                     vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; |                     vertex_index_expr = [this, name]() { return b.Expr(name); }; | ||||||
|                     break; |                     break; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @ -842,11 +870,11 @@ struct State { | |||||||
|         if (!instance_index_expr) { |         if (!instance_index_expr) { | ||||||
|             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { |             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | ||||||
|                 if (layout.step_mode == VertexStepMode::kInstance) { |                 if (layout.step_mode == VertexStepMode::kInstance) { | ||||||
|                     auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); |                     auto name = b.Symbols().New("tint_pulling_instance_index"); | ||||||
|                     new_function_parameters.Push(ctx.dst->Param( |                     new_function_parameters.Push( | ||||||
|                         name, ctx.dst->ty.u32(), |                         b.Param(name, b.ty.u32(), | ||||||
|                         utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)})); |                                 utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)})); | ||||||
|                     instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; |                     instance_index_expr = [this, name]() { return b.Expr(name); }; | ||||||
|                     break; |                     break; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @ -864,53 +892,24 @@ struct State { | |||||||
|         auto attrs = ctx.Clone(func->attributes); |         auto attrs = ctx.Clone(func->attributes); | ||||||
|         auto ret_attrs = ctx.Clone(func->return_type_attributes); |         auto ret_attrs = ctx.Clone(func->return_type_attributes); | ||||||
|         auto* new_func = |         auto* new_func = | ||||||
|             ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters, |             b.create<ast::Function>(func->source, func_sym, new_function_parameters, ret_type, body, | ||||||
|                                            ret_type, body, std::move(attrs), std::move(ret_attrs)); |                                     std::move(attrs), std::move(ret_attrs)); | ||||||
|         ctx.Replace(func, new_func); |         ctx.Replace(func, new_func); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace
 |  | ||||||
| 
 |  | ||||||
| VertexPulling::VertexPulling() = default; | VertexPulling::VertexPulling() = default; | ||||||
| VertexPulling::~VertexPulling() = default; | VertexPulling::~VertexPulling() = default; | ||||||
| 
 | 
 | ||||||
| void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | Transform::ApplyResult VertexPulling::Apply(const Program* src, | ||||||
|  |                                             const DataMap& inputs, | ||||||
|  |                                             DataMap&) const { | ||||||
|     auto cfg = cfg_; |     auto cfg = cfg_; | ||||||
|     if (auto* cfg_data = inputs.Get<Config>()) { |     if (auto* cfg_data = inputs.Get<Config>()) { | ||||||
|         cfg = *cfg_data; |         cfg = *cfg_data; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Find entry point
 |     return State{src, cfg}.Run(); | ||||||
|     const ast::Function* func = nullptr; |  | ||||||
|     for (auto* fn : ctx.src->AST().Functions()) { |  | ||||||
|         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { |  | ||||||
|             if (func != nullptr) { |  | ||||||
|                 ctx.dst->Diagnostics().add_error( |  | ||||||
|                     diag::System::Transform, |  | ||||||
|                     "VertexPulling found more than one vertex entry point"); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|             func = fn; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     if (func == nullptr) { |  | ||||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, |  | ||||||
|                                          "Vertex stage entry point not found"); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // TODO(idanr): Need to check shader locations in descriptor cover all
 |  | ||||||
|     // attributes
 |  | ||||||
| 
 |  | ||||||
|     // TODO(idanr): Make sure we covered all error cases, to guarantee the
 |  | ||||||
|     // following stages will pass
 |  | ||||||
| 
 |  | ||||||
|     State state{ctx, cfg}; |  | ||||||
|     state.AddVertexStorageBuffers(); |  | ||||||
|     state.Process(func); |  | ||||||
| 
 |  | ||||||
|     ctx.Clone(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| VertexPulling::Config::Config() = default; | VertexPulling::Config::Config() = default; | ||||||
|  | |||||||
| @ -171,16 +171,14 @@ class VertexPulling final : public Castable<VertexPulling, Transform> { | |||||||
|     /// Destructor
 |     /// Destructor
 | ||||||
|     ~VertexPulling() override; |     ~VertexPulling() override; | ||||||
| 
 | 
 | ||||||
|   protected: |     /// @copydoc Transform::Apply
 | ||||||
|     /// Runs the transform using the CloneContext built for transforming a
 |     ApplyResult Apply(const Program* program, | ||||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 |                       const DataMap& inputs, | ||||||
|     /// @param ctx the CloneContext primed with the input program and
 |                       DataMap& outputs) const override; | ||||||
|     /// ProgramBuilder
 |  | ||||||
|     /// @param inputs optional extra transform-specific input data
 |  | ||||||
|     /// @param outputs optional extra transform-specific output data
 |  | ||||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; |  | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|  |     struct State; | ||||||
|  | 
 | ||||||
|     Config cfg_; |     Config cfg_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -14,18 +14,17 @@ | |||||||
| 
 | 
 | ||||||
| #include "src/tint/transform/while_to_loop.h" | #include "src/tint/transform/while_to_loop.h" | ||||||
| 
 | 
 | ||||||
|  | #include <utility> | ||||||
|  | 
 | ||||||
| #include "src/tint/ast/break_statement.h" | #include "src/tint/ast/break_statement.h" | ||||||
| #include "src/tint/program_builder.h" | #include "src/tint/program_builder.h" | ||||||
| 
 | 
 | ||||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop); | TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop); | ||||||
| 
 | 
 | ||||||
| namespace tint::transform { | namespace tint::transform { | ||||||
|  | namespace { | ||||||
| 
 | 
 | ||||||
| WhileToLoop::WhileToLoop() = default; | bool ShouldRun(const Program* program) { | ||||||
| 
 |  | ||||||
| WhileToLoop::~WhileToLoop() = default; |  | ||||||
| 
 |  | ||||||
| bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { |  | ||||||
|     for (auto* node : program->ASTNodes().Objects()) { |     for (auto* node : program->ASTNodes().Objects()) { | ||||||
|         if (node->Is<ast::WhileStatement>()) { |         if (node->Is<ast::WhileStatement>()) { | ||||||
|             return true; |             return true; | ||||||
| @ -34,20 +33,32 @@ bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | WhileToLoop::WhileToLoop() = default; | ||||||
|  | 
 | ||||||
|  | WhileToLoop::~WhileToLoop() = default; | ||||||
|  | 
 | ||||||
|  | Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||||
|  |     if (!ShouldRun(src)) { | ||||||
|  |         return SkipTransform; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ProgramBuilder b; | ||||||
|  |     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||||
|  | 
 | ||||||
|     ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* { |     ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* { | ||||||
|         utils::Vector<const ast::Statement*, 16> stmts; |         utils::Vector<const ast::Statement*, 16> stmts; | ||||||
|         auto* cond = w->condition; |         auto* cond = w->condition; | ||||||
| 
 | 
 | ||||||
|         // !condition
 |         // !condition
 | ||||||
|         auto* not_cond = |         auto* not_cond = b.Not(ctx.Clone(cond)); | ||||||
|             ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond)); |  | ||||||
| 
 | 
 | ||||||
|         // { break; }
 |         // { break; }
 | ||||||
|         auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>()); |         auto* break_body = b.Block(b.Break()); | ||||||
| 
 | 
 | ||||||
|         // if (!condition) { break; }
 |         // if (!condition) { break; }
 | ||||||
|         stmts.Push(ctx.dst->If(not_cond, break_body)); |         stmts.Push(b.If(not_cond, break_body)); | ||||||
| 
 | 
 | ||||||
|         for (auto* stmt : w->body->statements) { |         for (auto* stmt : w->body->statements) { | ||||||
|             stmts.Push(ctx.Clone(stmt)); |             stmts.Push(ctx.Clone(stmt)); | ||||||
| @ -55,13 +66,14 @@ void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | |||||||
| 
 | 
 | ||||||
|         const ast::BlockStatement* continuing = nullptr; |         const ast::BlockStatement* continuing = nullptr; | ||||||
| 
 | 
 | ||||||
|         auto* body = ctx.dst->Block(stmts); |         auto* body = b.Block(stmts); | ||||||
|         auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing); |         auto* loop = b.Loop(body, continuing); | ||||||
| 
 | 
 | ||||||
|         return loop; |         return loop; | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     ctx.Clone(); |     ctx.Clone(); | ||||||
|  |     return Program(std::move(b)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tint::transform
 | }  // namespace tint::transform
 | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user