mirror of
				https://github.com/encounter/dawn-cmake.git
				synced 2025-10-26 11:40:29 +00:00 
			
		
		
		
	tint/transform: Refactor transforms
Replace the ShouldRun() method with Apply() which will do the transformation if it needs to be done, otherwise returns 'SkipTransform'. This reduces a bunch of duplicated scanning between the old ShouldRun() and Transform(). This change also adjusts code style to make the transforms more consistent. Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681 Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
		
							parent
							
								
									de6db384aa
								
							
						
					
					
						commit
						c6b381495d
					
				| @ -15,6 +15,7 @@ | ||||
| #include "src/tint/fuzzers/shuffle_transform.h" | ||||
| 
 | ||||
| #include <random> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/program_builder.h" | ||||
| 
 | ||||
| @ -22,15 +23,21 @@ namespace tint::fuzzers { | ||||
| 
 | ||||
| ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {} | ||||
| 
 | ||||
| void ShuffleTransform::Run(CloneContext& ctx, | ||||
|                            const tint::transform::DataMap&, | ||||
|                            tint::transform::DataMap&) const { | ||||
|     auto decls = ctx.src->AST().GlobalDeclarations(); | ||||
| transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src, | ||||
|                                                           const transform::DataMap&, | ||||
|                                                           transform::DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     auto decls = src->AST().GlobalDeclarations(); | ||||
|     auto rng = std::mt19937_64{seed_}; | ||||
|     std::shuffle(std::begin(decls), std::end(decls), rng); | ||||
|     for (auto* decl : decls) { | ||||
|         ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl)); | ||||
|         b.AST().AddGlobalDeclaration(ctx.Clone(decl)); | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::fuzzers
 | ||||
|  | ||||
| @ -20,16 +20,16 @@ | ||||
| namespace tint::fuzzers { | ||||
| 
 | ||||
| /// ShuffleTransform reorders the module scope declarations into a random order
 | ||||
| class ShuffleTransform : public tint::transform::Transform { | ||||
| class ShuffleTransform : public transform::Transform { | ||||
|   public: | ||||
|     /// Constructor
 | ||||
|     /// @param seed the random seed to use for the shuffling
 | ||||
|     explicit ShuffleTransform(size_t seed); | ||||
| 
 | ||||
|   protected: | ||||
|     void Run(CloneContext& ctx, | ||||
|              const tint::transform::DataMap&, | ||||
|              tint::transform::DataMap&) const override; | ||||
|     /// @copydoc transform::Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const transform::DataMap& inputs, | ||||
|                       transform::DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     size_t seed_; | ||||
|  | ||||
| @ -23,6 +23,7 @@ | ||||
| 
 | ||||
| #include "src/tint/ast/access.h" | ||||
| #include "src/tint/ast/address_space.h" | ||||
| #include "src/tint/ast/parameter.h" | ||||
| #include "src/tint/sem/binding_point.h" | ||||
| #include "src/tint/sem/expression.h" | ||||
| #include "src/tint/sem/parameter_usage.h" | ||||
| @ -212,6 +213,11 @@ class Parameter final : public Castable<Parameter, Variable> { | ||||
|     /// Destructor
 | ||||
|     ~Parameter() override; | ||||
| 
 | ||||
|     /// @returns the AST declaration node
 | ||||
|     const ast::Parameter* Declaration() const { | ||||
|         return static_cast<const ast::Parameter*>(Variable::Declaration()); | ||||
|     } | ||||
| 
 | ||||
|     /// @return the index of the parmeter in the function
 | ||||
|     uint32_t Index() const { return index_; } | ||||
| 
 | ||||
|  | ||||
| @ -31,21 +31,29 @@ AddBlockAttribute::AddBlockAttribute() = default; | ||||
| 
 | ||||
| AddBlockAttribute::~AddBlockAttribute() = default; | ||||
| 
 | ||||
| void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     auto& sem = ctx.src->Sem(); | ||||
| Transform::ApplyResult AddBlockAttribute::Apply(const Program* src, | ||||
|                                                 const DataMap&, | ||||
|                                                 DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     auto& sem = src->Sem(); | ||||
| 
 | ||||
|     // A map from a type in the source program to a block-decorated wrapper that contains it in the
 | ||||
|     // destination program.
 | ||||
|     utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs; | ||||
| 
 | ||||
|     // Process global 'var' declarations that are buffers.
 | ||||
|     for (auto* global : ctx.src->AST().GlobalVariables()) { | ||||
|     bool made_changes = false; | ||||
|     for (auto* global : src->AST().GlobalVariables()) { | ||||
|         auto* var = sem.Get(global); | ||||
|         if (!ast::IsHostShareable(var->AddressSpace())) { | ||||
|             // Not declared in a host-sharable address space
 | ||||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         made_changes = true; | ||||
| 
 | ||||
|         auto* ty = var->Type()->UnwrapRef(); | ||||
|         auto* str = ty->As<sem::Struct>(); | ||||
| 
 | ||||
| @ -61,33 +69,36 @@ void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|             const char* kMemberName = "inner"; | ||||
| 
 | ||||
|             auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] { | ||||
|                 auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(), | ||||
|                                                                          ctx.dst->AllocateNodeID()); | ||||
|                 auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block"; | ||||
|                 auto* ret = ctx.dst->create<ast::Struct>( | ||||
|                     ctx.dst->Symbols().New(wrapper_name), | ||||
|                     utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))}, | ||||
|                 auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID()); | ||||
|                 auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block"; | ||||
|                 auto* ret = b.create<ast::Struct>( | ||||
|                     b.Symbols().New(wrapper_name), | ||||
|                     utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))}, | ||||
|                     utils::Vector{block}); | ||||
|                 ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret); | ||||
|                 ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret); | ||||
|                 return ret; | ||||
|             }); | ||||
|             ctx.Replace(global->type, ctx.dst->ty.Of(wrapper)); | ||||
|             ctx.Replace(global->type, b.ty.Of(wrapper)); | ||||
| 
 | ||||
|             // Insert a member accessor to get the original type from the wrapper at
 | ||||
|             // any usage of the original variable.
 | ||||
|             for (auto* user : var->Users()) { | ||||
|                 ctx.Replace(user->Declaration(), | ||||
|                             ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName)); | ||||
|                             b.MemberAccessor(ctx.Clone(global->symbol), kMemberName)); | ||||
|             } | ||||
|         } else { | ||||
|             // Add a block attribute to this struct directly.
 | ||||
|             auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(), | ||||
|                                                                      ctx.dst->AllocateNodeID()); | ||||
|             auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID()); | ||||
|             ctx.InsertFront(str->Declaration()->attributes, block); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid) | ||||
|  | ||||
| @ -53,14 +53,10 @@ class AddBlockAttribute final : public Castable<AddBlockAttribute, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~AddBlockAttribute() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -23,12 +23,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint); | ||||
| using namespace tint::number_suffixes;  // NOLINT
 | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| AddEmptyEntryPoint::AddEmptyEntryPoint() = default; | ||||
| 
 | ||||
| AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; | ||||
| 
 | ||||
| bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* func : program->AST().Functions()) { | ||||
|         if (func->IsEntryPoint()) { | ||||
|             return false; | ||||
| @ -37,13 +34,30 @@ bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
| void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {}, | ||||
|                   utils::Vector{ | ||||
|                       ctx.dst->Stage(ast::PipelineStage::kCompute), | ||||
|                       ctx.dst->WorkgroupSize(1_i), | ||||
|                   }); | ||||
| }  // namespace
 | ||||
| 
 | ||||
| AddEmptyEntryPoint::AddEmptyEntryPoint() = default; | ||||
| 
 | ||||
| AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; | ||||
| 
 | ||||
| Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src, | ||||
|                                                  const DataMap&, | ||||
|                                                  DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {}, | ||||
|            utils::Vector{ | ||||
|                b.Stage(ast::PipelineStage::kCompute), | ||||
|                b.WorkgroupSize(1_i), | ||||
|            }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -27,19 +27,10 @@ class AddEmptyEntryPoint final : public Castable<AddEmptyEntryPoint, Transform> | ||||
|     /// Destructor
 | ||||
|     ~AddEmptyEntryPoint() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,13 +31,153 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||
|                     return true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| ArrayLengthFromUniform::ArrayLengthFromUniform() = default; | ||||
| ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; | ||||
| 
 | ||||
| /// The PIMPL state for this transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct ArrayLengthFromUniform::State { | ||||
|     /// Constructor
 | ||||
|     /// @param program the source program
 | ||||
|     /// @param in the input transform data
 | ||||
|     /// @param out the output transform data
 | ||||
|     explicit State(const Program* program, const DataMap& in, DataMap& out) | ||||
|         : src(program), inputs(in), outputs(out) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         auto* cfg = inputs.Get<Config>(); | ||||
|         if (cfg == nullptr) { | ||||
|             b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                       "missing transform data for " + | ||||
|                                           std::string(TypeInfo::Of<ArrayLengthFromUniform>().name)); | ||||
|             return Program(std::move(b)); | ||||
|         } | ||||
| 
 | ||||
|         if (!ShouldRun(ctx.src)) { | ||||
|             return SkipTransform; | ||||
|         } | ||||
| 
 | ||||
|         const char* kBufferSizeMemberName = "buffer_size"; | ||||
| 
 | ||||
|         // Determine the size of the buffer size array.
 | ||||
|         uint32_t max_buffer_size_index = 0; | ||||
| 
 | ||||
|         IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*, | ||||
|                                            const sem::GlobalVariable* var) { | ||||
|             auto binding = var->BindingPoint(); | ||||
|             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||
|             if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||
|                 return; | ||||
|             } | ||||
|             if (idx_itr->second > max_buffer_size_index) { | ||||
|                 max_buffer_size_index = idx_itr->second; | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         // Get (or create, on first call) the uniform buffer that will receive the
 | ||||
|         // size of each storage buffer in the module.
 | ||||
|         const ast::Variable* buffer_size_ubo = nullptr; | ||||
|         auto get_ubo = [&]() { | ||||
|             if (!buffer_size_ubo) { | ||||
|                 // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
 | ||||
|                 // We do this because UBOs require an element stride that is 16-byte
 | ||||
|                 // aligned.
 | ||||
|                 auto* buffer_size_struct = b.Structure( | ||||
|                     b.Sym(), utils::Vector{ | ||||
|                                  b.Member(kBufferSizeMemberName, | ||||
|                                           b.ty.array(b.ty.vec4(b.ty.u32()), | ||||
|                                                      u32((max_buffer_size_index / 4) + 1))), | ||||
|                              }); | ||||
|                 buffer_size_ubo = | ||||
|                     b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform, | ||||
|                                 b.Group(AInt(cfg->ubo_binding.group)), | ||||
|                                 b.Binding(AInt(cfg->ubo_binding.binding))); | ||||
|             } | ||||
|             return buffer_size_ubo; | ||||
|         }; | ||||
| 
 | ||||
|         std::unordered_set<uint32_t> used_size_indices; | ||||
| 
 | ||||
|         IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, | ||||
|                                            const sem::VariableUser* storage_buffer_sem, | ||||
|                                            const sem::GlobalVariable* var) { | ||||
|             auto binding = var->BindingPoint(); | ||||
|             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||
|             if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||
|                 return; | ||||
|             } | ||||
| 
 | ||||
|             uint32_t size_index = idx_itr->second; | ||||
|             used_size_indices.insert(size_index); | ||||
| 
 | ||||
|             // Load the total storage buffer size from the UBO.
 | ||||
|             uint32_t array_index = size_index / 4; | ||||
|             auto* vec_expr = b.IndexAccessor( | ||||
|                 b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); | ||||
|             uint32_t vec_index = size_index % 4; | ||||
|             auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index)); | ||||
| 
 | ||||
|             // Calculate actual array length
 | ||||
|             //                total_storage_buffer_size - array_offset
 | ||||
|             // array_length = ----------------------------------------
 | ||||
|             //                             array_stride
 | ||||
|             const ast::Expression* total_size = total_storage_buffer_size; | ||||
|             auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); | ||||
|             const sem::Array* array_type = nullptr; | ||||
|             if (auto* str = storage_buffer_type->As<sem::Struct>()) { | ||||
|                 // The variable is a struct, so subtract the byte offset of the array
 | ||||
|                 // member.
 | ||||
|                 auto* array_member_sem = str->Members().back(); | ||||
|                 array_type = array_member_sem->Type()->As<sem::Array>(); | ||||
|                 total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); | ||||
|             } else if (auto* arr = storage_buffer_type->As<sem::Array>()) { | ||||
|                 array_type = arr; | ||||
|             } else { | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "expected form of arrayLength argument to be &array_var or " | ||||
|                        "&struct_var.array_member"; | ||||
|                 return; | ||||
|             } | ||||
|             auto* array_length = b.Div(total_size, u32(array_type->Stride())); | ||||
| 
 | ||||
|             ctx.Replace(call_expr, array_length); | ||||
|         }); | ||||
| 
 | ||||
|         outputs.Add<Result>(used_size_indices); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The transform inputs
 | ||||
|     const DataMap& inputs; | ||||
|     /// The transform outputs
 | ||||
|     DataMap& outputs; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// Iterate over all arrayLength() builtins that operate on
 | ||||
|     /// storage buffer variables.
 | ||||
| @ -48,10 +188,10 @@ struct ArrayLengthFromUniform::State { | ||||
|     /// sem::GlobalVariable for the storage buffer.
 | ||||
|     template <typename F> | ||||
|     void IterateArrayLengthOnStorageVar(F&& functor) { | ||||
|         auto& sem = ctx.src->Sem(); | ||||
|         auto& sem = src->Sem(); | ||||
| 
 | ||||
|         // Find all calls to the arrayLength() builtin.
 | ||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|         for (auto* node : src->ASTNodes().Objects()) { | ||||
|             auto* call_expr = node->As<ast::CallExpression>(); | ||||
|             if (!call_expr) { | ||||
|                 continue; | ||||
| @ -79,7 +219,7 @@ struct ArrayLengthFromUniform::State { | ||||
|             //   arrayLength(&array_var)
 | ||||
|             auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>(); | ||||
|             if (!param || param->op != ast::UnaryOp::kAddressOf) { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "expected form of arrayLength argument to be &array_var or " | ||||
|                        "&struct_var.array_member"; | ||||
|                 break; | ||||
| @ -90,7 +230,7 @@ struct ArrayLengthFromUniform::State { | ||||
|             } | ||||
|             auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); | ||||
|             if (!storage_buffer_sem) { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "expected form of arrayLength argument to be &array_var or " | ||||
|                        "&struct_var.array_member"; | ||||
|                 break; | ||||
| @ -99,8 +239,7 @@ struct ArrayLengthFromUniform::State { | ||||
|             // Get the index to use for the buffer size array.
 | ||||
|             auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable()); | ||||
|             if (!var) { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                     << "storage buffer is not a global variable"; | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable"; | ||||
|                 break; | ||||
|             } | ||||
|             functor(call_expr, storage_buffer_sem, var); | ||||
| @ -108,117 +247,10 @@ struct ArrayLengthFromUniform::State { | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||
|                     return true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { | ||||
|     auto* cfg = inputs.Get<Config>(); | ||||
|     if (cfg == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     const char* kBufferSizeMemberName = "buffer_size"; | ||||
| 
 | ||||
|     // Determine the size of the buffer size array.
 | ||||
|     uint32_t max_buffer_size_index = 0; | ||||
| 
 | ||||
|     State{ctx}.IterateArrayLengthOnStorageVar( | ||||
|         [&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) { | ||||
|             auto binding = var->BindingPoint(); | ||||
|             auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||
|             if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||
|                 return; | ||||
|             } | ||||
|             if (idx_itr->second > max_buffer_size_index) { | ||||
|                 max_buffer_size_index = idx_itr->second; | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|     // Get (or create, on first call) the uniform buffer that will receive the
 | ||||
|     // size of each storage buffer in the module.
 | ||||
|     const ast::Variable* buffer_size_ubo = nullptr; | ||||
|     auto get_ubo = [&]() { | ||||
|         if (!buffer_size_ubo) { | ||||
|             // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
 | ||||
|             // We do this because UBOs require an element stride that is 16-byte
 | ||||
|             // aligned.
 | ||||
|             auto* buffer_size_struct = ctx.dst->Structure( | ||||
|                 ctx.dst->Sym(), | ||||
|                 utils::Vector{ | ||||
|                     ctx.dst->Member(kBufferSizeMemberName, | ||||
|                                     ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), | ||||
|                                                       u32((max_buffer_size_index / 4) + 1))), | ||||
|                 }); | ||||
|             buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), | ||||
|                                                  ast::AddressSpace::kUniform, | ||||
|                                                  ctx.dst->Group(AInt(cfg->ubo_binding.group)), | ||||
|                                                  ctx.dst->Binding(AInt(cfg->ubo_binding.binding))); | ||||
|         } | ||||
|         return buffer_size_ubo; | ||||
|     }; | ||||
| 
 | ||||
|     std::unordered_set<uint32_t> used_size_indices; | ||||
| 
 | ||||
|     State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, | ||||
|                                                   const sem::VariableUser* storage_buffer_sem, | ||||
|                                                   const sem::GlobalVariable* var) { | ||||
|         auto binding = var->BindingPoint(); | ||||
|         auto idx_itr = cfg->bindpoint_to_size_index.find(binding); | ||||
|         if (idx_itr == cfg->bindpoint_to_size_index.end()) { | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
|         uint32_t size_index = idx_itr->second; | ||||
|         used_size_indices.insert(size_index); | ||||
| 
 | ||||
|         // Load the total storage buffer size from the UBO.
 | ||||
|         uint32_t array_index = size_index / 4; | ||||
|         auto* vec_expr = ctx.dst->IndexAccessor( | ||||
|             ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); | ||||
|         uint32_t vec_index = size_index % 4; | ||||
|         auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index)); | ||||
| 
 | ||||
|         // Calculate actual array length
 | ||||
|         //                total_storage_buffer_size - array_offset
 | ||||
|         // array_length = ----------------------------------------
 | ||||
|         //                             array_stride
 | ||||
|         const ast::Expression* total_size = total_storage_buffer_size; | ||||
|         auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); | ||||
|         const sem::Array* array_type = nullptr; | ||||
|         if (auto* str = storage_buffer_type->As<sem::Struct>()) { | ||||
|             // The variable is a struct, so subtract the byte offset of the array
 | ||||
|             // member.
 | ||||
|             auto* array_member_sem = str->Members().back(); | ||||
|             array_type = array_member_sem->Type()->As<sem::Array>(); | ||||
|             total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); | ||||
|         } else if (auto* arr = storage_buffer_type->As<sem::Array>()) { | ||||
|             array_type = arr; | ||||
|         } else { | ||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                 << "expected form of arrayLength argument to be &array_var or " | ||||
|                    "&struct_var.array_member"; | ||||
|             return; | ||||
|         } | ||||
|         auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride())); | ||||
| 
 | ||||
|         ctx.Replace(call_expr, array_length); | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
| 
 | ||||
|     outputs.Add<Result>(used_size_indices); | ||||
| Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src, | ||||
|                                                      const DataMap& inputs, | ||||
|                                                      DataMap& outputs) const { | ||||
|     return State{src, inputs, outputs}.Run(); | ||||
| } | ||||
| 
 | ||||
| ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} | ||||
|  | ||||
| @ -100,22 +100,12 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra | ||||
|         std::unordered_set<uint32_t> used_size_indices; | ||||
|     }; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     /// The PIMPL state for this transform
 | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -28,7 +28,13 @@ using ArrayLengthFromUniformTest = TransformTest; | ||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) { | ||||
|     auto* src = R"()"; | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src)); | ||||
|     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||
|     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) { | ||||
| @ -45,7 +51,13 @@ fn main() { | ||||
| } | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src)); | ||||
|     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||
|     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) { | ||||
| @ -63,7 +75,13 @@ fn main() { | ||||
| } | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src)); | ||||
|     ArrayLengthFromUniform::Config cfg({0, 30u}); | ||||
|     cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { | ||||
|  | ||||
| @ -40,19 +40,21 @@ BindingRemapper::Remappings::~Remappings() = default; | ||||
| BindingRemapper::BindingRemapper() = default; | ||||
| BindingRemapper::~BindingRemapper() = default; | ||||
| 
 | ||||
| bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const { | ||||
|     if (auto* remappings = inputs.Get<Remappings>()) { | ||||
|         return !remappings->binding_points.empty() || !remappings->access_controls.empty(); | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| Transform::ApplyResult BindingRemapper::Apply(const Program* src, | ||||
|                                               const DataMap& inputs, | ||||
|                                               DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
| void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
|     auto* remappings = inputs.Get<Remappings>(); | ||||
|     if (!remappings) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     if (remappings->binding_points.empty() && remappings->access_controls.empty()) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     // A set of post-remapped binding points that need to be decorated with a
 | ||||
| @ -62,11 +64,11 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | ||||
|     if (remappings->allow_collisions) { | ||||
|         // Scan for binding point collisions generated by this transform.
 | ||||
|         // Populate all collisions in the `add_collision_attr` set.
 | ||||
|         for (auto* func_ast : ctx.src->AST().Functions()) { | ||||
|         for (auto* func_ast : src->AST().Functions()) { | ||||
|             if (!func_ast->IsEntryPoint()) { | ||||
|                 continue; | ||||
|             } | ||||
|             auto* func = ctx.src->Sem().Get(func_ast); | ||||
|             auto* func = src->Sem().Get(func_ast); | ||||
|             std::unordered_map<sem::BindingPoint, int> binding_point_counts; | ||||
|             for (auto* global : func->TransitivelyReferencedGlobals()) { | ||||
|                 if (global->Declaration()->HasBindingPoint()) { | ||||
| @ -90,9 +92,9 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     for (auto* var : ctx.src->AST().Globals<ast::Var>()) { | ||||
|     for (auto* var : src->AST().Globals<ast::Var>()) { | ||||
|         if (var->HasBindingPoint()) { | ||||
|             auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var); | ||||
|             auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var); | ||||
| 
 | ||||
|             // The original binding point
 | ||||
|             BindingPoint from = global_sem->BindingPoint(); | ||||
| @ -106,8 +108,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | ||||
|             auto bp_it = remappings->binding_points.find(from); | ||||
|             if (bp_it != remappings->binding_points.end()) { | ||||
|                 BindingPoint to = bp_it->second; | ||||
|                 auto* new_group = ctx.dst->Group(AInt(to.group)); | ||||
|                 auto* new_binding = ctx.dst->Binding(AInt(to.binding)); | ||||
|                 auto* new_group = b.Group(AInt(to.group)); | ||||
|                 auto* new_binding = b.Binding(AInt(to.binding)); | ||||
| 
 | ||||
|                 auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes); | ||||
|                 auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes); | ||||
| @ -122,37 +124,37 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co | ||||
|             if (ac_it != remappings->access_controls.end()) { | ||||
|                 ast::Access ac = ac_it->second; | ||||
|                 if (ac == ast::Access::kUndefined) { | ||||
|                     ctx.dst->Diagnostics().add_error( | ||||
|                     b.Diagnostics().add_error( | ||||
|                         diag::System::Transform, | ||||
|                         "invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")"); | ||||
|                     return; | ||||
|                     return Program(std::move(b)); | ||||
|                 } | ||||
|                 auto* sem = ctx.src->Sem().Get(var); | ||||
|                 auto* sem = src->Sem().Get(var); | ||||
|                 if (sem->AddressSpace() != ast::AddressSpace::kStorage) { | ||||
|                     ctx.dst->Diagnostics().add_error( | ||||
|                     b.Diagnostics().add_error( | ||||
|                         diag::System::Transform, | ||||
|                         "cannot apply access control to variable with address space " + | ||||
|                             std::string(utils::ToString(sem->AddressSpace()))); | ||||
|                     return; | ||||
|                     return Program(std::move(b)); | ||||
|                 } | ||||
|                 auto* ty = sem->Type()->UnwrapRef(); | ||||
|                 const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); | ||||
|                 auto* new_var = | ||||
|                     ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, | ||||
|                                  var->declared_address_space, ac, ctx.Clone(var->initializer), | ||||
|                                  ctx.Clone(var->attributes)); | ||||
|                 auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, | ||||
|                                       var->declared_address_space, ac, ctx.Clone(var->initializer), | ||||
|                                       ctx.Clone(var->attributes)); | ||||
|                 ctx.Replace(var, new_var); | ||||
|             } | ||||
| 
 | ||||
|             // Add `DisableValidationAttribute`s if required
 | ||||
|             if (add_collision_attr.count(bp)) { | ||||
|                 auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision); | ||||
|                 auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision); | ||||
|                 ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -67,19 +67,10 @@ class BindingRemapper final : public Castable<BindingRemapper, Transform> { | ||||
|     BindingRemapper(); | ||||
|     ~BindingRemapper() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -23,12 +23,6 @@ namespace { | ||||
| 
 | ||||
| using BindingRemapperTest = TransformTest; | ||||
| 
 | ||||
| TEST_F(BindingRemapperTest, ShouldRunNoRemappings) { | ||||
|     auto* src = R"()"; | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<BindingRemapper>(src)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) { | ||||
|     auto* src = R"()"; | ||||
| 
 | ||||
| @ -350,7 +344,7 @@ fn f() { | ||||
| } | ||||
| )"; | ||||
| 
 | ||||
|     auto* expect = src; | ||||
|     auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)"; | ||||
| 
 | ||||
|     auto got = Run<BindingRemapper>(src); | ||||
| 
 | ||||
|  | ||||
| @ -29,7 +29,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::BuiltinPolyfill::Config); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// The PIMPL state for the BuiltinPolyfill transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct BuiltinPolyfill::State { | ||||
|     /// Constructor
 | ||||
|     /// @param c the CloneContext
 | ||||
| @ -604,193 +604,100 @@ BuiltinPolyfill::BuiltinPolyfill() = default; | ||||
| 
 | ||||
| BuiltinPolyfill::~BuiltinPolyfill() = default; | ||||
| 
 | ||||
| bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const { | ||||
|     if (auto* cfg = data.Get<Config>()) { | ||||
|         auto builtins = cfg->builtins; | ||||
|         auto& sem = program->Sem(); | ||||
|         for (auto* node : program->ASTNodes().Objects()) { | ||||
|             if (auto* call = sem.Get<sem::Call>(node)) { | ||||
|                 if (auto* builtin = call->Target()->As<sem::Builtin>()) { | ||||
|                     if (call->Stage() == sem::EvaluationStage::kConstant) { | ||||
|                         continue;  // Don't polyfill @const expressions
 | ||||
|                     } | ||||
|                     switch (builtin->Type()) { | ||||
|                         case sem::BuiltinType::kAcosh: | ||||
|                             if (builtins.acosh != Level::kNone) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kAsinh: | ||||
|                             if (builtins.asinh) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kAtanh: | ||||
|                             if (builtins.atanh != Level::kNone) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kClamp: | ||||
|                             if (builtins.clamp_int) { | ||||
|                                 auto& sig = builtin->Signature(); | ||||
|                                 return sig.parameters[0]->Type()->is_integer_scalar_or_vector(); | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kCountLeadingZeros: | ||||
|                             if (builtins.count_leading_zeros) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kCountTrailingZeros: | ||||
|                             if (builtins.count_trailing_zeros) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kExtractBits: | ||||
|                             if (builtins.extract_bits != Level::kNone) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kFirstLeadingBit: | ||||
|                             if (builtins.first_leading_bit) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kFirstTrailingBit: | ||||
|                             if (builtins.first_trailing_bit) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kInsertBits: | ||||
|                             if (builtins.insert_bits != Level::kNone) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kSaturate: | ||||
|                             if (builtins.saturate) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kTextureSampleBaseClampToEdge: | ||||
|                             if (builtins.texture_sample_base_clamp_to_edge_2d_f32) { | ||||
|                                 auto& sig = builtin->Signature(); | ||||
|                                 auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); | ||||
|                                 if (auto* stex = tex->Type()->As<sem::SampledTexture>()) { | ||||
|                                     return stex->type()->Is<sem::F32>(); | ||||
|                                 } | ||||
|                             } | ||||
|                             break; | ||||
|                         case sem::BuiltinType::kQuantizeToF16: | ||||
|                             if (builtins.quantize_to_vec_f16) { | ||||
|                                 if (builtin->ReturnType()->Is<sem::Vector>()) { | ||||
|                                     return true; | ||||
|                                 } | ||||
|                             } | ||||
|                             break; | ||||
|                         default: | ||||
|                             break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const { | ||||
| Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, | ||||
|                                               const DataMap& data, | ||||
|                                               DataMap&) const { | ||||
|     auto* cfg = data.Get<Config>(); | ||||
|     if (!cfg) { | ||||
|         ctx.Clone(); | ||||
|         return; | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     std::unordered_map<const sem::Builtin*, Symbol> polyfills; | ||||
|     auto& builtins = cfg->builtins; | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { | ||||
|         auto builtins = cfg->builtins; | ||||
|         State s{ctx, builtins}; | ||||
|         if (auto* call = s.sem.Get<sem::Call>(expr)) { | ||||
|     utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills; | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     State s{ctx, builtins}; | ||||
| 
 | ||||
|     bool made_changes = false; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* call = src->Sem().Get<sem::Call>(node)) { | ||||
|             if (auto* builtin = call->Target()->As<sem::Builtin>()) { | ||||
|                 if (call->Stage() == sem::EvaluationStage::kConstant) { | ||||
|                     return nullptr;  // Don't polyfill @const expressions
 | ||||
|                     continue;  // Don't polyfill @const expressions
 | ||||
|                 } | ||||
|                 Symbol polyfill; | ||||
|                 switch (builtin->Type()) { | ||||
|                     case sem::BuiltinType::kAcosh: | ||||
|                         if (builtins.acosh != Level::kNone) { | ||||
|                             polyfill = utils::GetOrCreate( | ||||
|                                 polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.acosh(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kAsinh: | ||||
|                         if (builtins.asinh) { | ||||
|                             polyfill = utils::GetOrCreate( | ||||
|                                 polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.asinh(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kAtanh: | ||||
|                         if (builtins.atanh != Level::kNone) { | ||||
|                             polyfill = utils::GetOrCreate( | ||||
|                                 polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.atanh(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kClamp: | ||||
|                         if (builtins.clamp_int) { | ||||
|                             auto& sig = builtin->Signature(); | ||||
|                             if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) { | ||||
|                                 polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                     return s.clampInteger(builtin->ReturnType()); | ||||
|                                 }); | ||||
|                                 polyfill = polyfills.GetOrCreate( | ||||
|                                     builtin, [&] { return s.clampInteger(builtin->ReturnType()); }); | ||||
|                             } | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kCountLeadingZeros: | ||||
|                         if (builtins.count_leading_zeros) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                             polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||
|                                 return s.countLeadingZeros(builtin->ReturnType()); | ||||
|                             }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kCountTrailingZeros: | ||||
|                         if (builtins.count_trailing_zeros) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                             polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||
|                                 return s.countTrailingZeros(builtin->ReturnType()); | ||||
|                             }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kExtractBits: | ||||
|                         if (builtins.extract_bits != Level::kNone) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                 return s.extractBits(builtin->ReturnType()); | ||||
|                             }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.extractBits(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kFirstLeadingBit: | ||||
|                         if (builtins.first_leading_bit) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                 return s.firstLeadingBit(builtin->ReturnType()); | ||||
|                             }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kFirstTrailingBit: | ||||
|                         if (builtins.first_trailing_bit) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                 return s.firstTrailingBit(builtin->ReturnType()); | ||||
|                             }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kInsertBits: | ||||
|                         if (builtins.insert_bits != Level::kNone) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                 return s.insertBits(builtin->ReturnType()); | ||||
|                             }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.insertBits(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kSaturate: | ||||
|                         if (builtins.saturate) { | ||||
|                             polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                 return s.saturate(builtin->ReturnType()); | ||||
|                             }); | ||||
|                             polyfill = polyfills.GetOrCreate( | ||||
|                                 builtin, [&] { return s.saturate(builtin->ReturnType()); }); | ||||
|                         } | ||||
|                         break; | ||||
|                     case sem::BuiltinType::kTextureSampleBaseClampToEdge: | ||||
| @ -799,7 +706,7 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | ||||
|                             auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); | ||||
|                             if (auto* stex = tex->Type()->As<sem::SampledTexture>()) { | ||||
|                                 if (stex->type()->Is<sem::F32>()) { | ||||
|                                     polyfill = utils::GetOrCreate(polyfills, builtin, [&] { | ||||
|                                     polyfill = polyfills.GetOrCreate(builtin, [&] { | ||||
|                                         return s.textureSampleBaseClampToEdge_2d_f32(); | ||||
|                                     }); | ||||
|                                 } | ||||
| @ -809,8 +716,8 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | ||||
|                     case sem::BuiltinType::kQuantizeToF16: | ||||
|                         if (builtins.quantize_to_vec_f16) { | ||||
|                             if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) { | ||||
|                                 polyfill = utils::GetOrCreate(polyfills, builtin, | ||||
|                                                               [&] { return s.quantizeToF16(vec); }); | ||||
|                                 polyfill = polyfills.GetOrCreate( | ||||
|                                     builtin, [&] { return s.quantizeToF16(vec); }); | ||||
|                             } | ||||
|                         } | ||||
|                         break; | ||||
| @ -819,14 +726,20 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons | ||||
|                         break; | ||||
|                 } | ||||
|                 if (polyfill.IsValid()) { | ||||
|                     return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); | ||||
|                     auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); | ||||
|                     ctx.Replace(call->Declaration(), replacement); | ||||
|                     made_changes = true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return nullptr; | ||||
|     }); | ||||
|     } | ||||
| 
 | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {} | ||||
|  | ||||
| @ -87,21 +87,13 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> { | ||||
|         const Builtins builtins; | ||||
|     }; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   protected: | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -1561,7 +1561,8 @@ fn f() { | ||||
| TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) { | ||||
|     auto* src = R"( | ||||
| fn f() { | ||||
|   let r : i32 = insertBits(1234, 5678, 5u, 6u); | ||||
|   let v = 1234i; | ||||
|   let r : i32 = insertBits(v, 5678, 5u, 6u); | ||||
| } | ||||
| )"; | ||||
| 
 | ||||
| @ -1975,10 +1976,6 @@ fn f() { | ||||
| )"; | ||||
| 
 | ||||
|     auto* expect = R"( | ||||
| @group(0) @binding(0) var t : texture_2d<f32>; | ||||
| 
 | ||||
| @group(0) @binding(1) var s : sampler; | ||||
| 
 | ||||
| fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> { | ||||
|   let dims = vec2<f32>(textureDimensions(t, 0)); | ||||
|   let half_texel = (vec2<f32>(0.5) / dims); | ||||
| @ -1986,6 +1983,10 @@ fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : v | ||||
|   return textureSampleLevel(t, s, clamped, 0); | ||||
| } | ||||
| 
 | ||||
| @group(0) @binding(0) var t : texture_2d<f32>; | ||||
| 
 | ||||
| @group(0) @binding(1) var s : sampler; | ||||
| 
 | ||||
| fn f() { | ||||
|   let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5)); | ||||
| } | ||||
|  | ||||
| @ -40,6 +40,19 @@ namespace tint::transform { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||
|                     return true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| /// ArrayUsage describes a runtime array usage.
 | ||||
| /// It is used as a key by the array_length_by_usage map.
 | ||||
| struct ArrayUsage { | ||||
| @ -73,21 +86,16 @@ const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSiz | ||||
| CalculateArrayLength::CalculateArrayLength() = default; | ||||
| CalculateArrayLength::~CalculateArrayLength() = default; | ||||
| 
 | ||||
| bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (auto* sem_fn = program->Sem().Get(fn)) { | ||||
|             for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { | ||||
|                 if (builtin->Type() == sem::BuiltinType::kArrayLength) { | ||||
|                     return true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| Transform::ApplyResult CalculateArrayLength::Apply(const Program* src, | ||||
|                                                    const DataMap&, | ||||
|                                                    DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     auto& sem = ctx.src->Sem(); | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     auto& sem = src->Sem(); | ||||
| 
 | ||||
|     // get_buffer_size_intrinsic() emits the function decorated with
 | ||||
|     // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
 | ||||
| @ -95,24 +103,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|     std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics; | ||||
|     auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) { | ||||
|         return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { | ||||
|             auto name = ctx.dst->Sym(); | ||||
|             auto name = b.Sym(); | ||||
|             auto* type = CreateASTTypeFor(ctx, buffer_type); | ||||
|             auto* disable_validation = | ||||
|                 ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter); | ||||
|             ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>( | ||||
|             auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter); | ||||
|             b.AST().AddFunction(b.create<ast::Function>( | ||||
|                 name, | ||||
|                 utils::Vector{ | ||||
|                     ctx.dst->Param("buffer", | ||||
|                                    ctx.dst->ty.pointer(type, buffer_type->AddressSpace(), | ||||
|                                                        buffer_type->Access()), | ||||
|                                    utils::Vector{disable_validation}), | ||||
|                     ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(), | ||||
|                                                                  ast::AddressSpace::kFunction)), | ||||
|                     b.Param("buffer", | ||||
|                             b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()), | ||||
|                             utils::Vector{disable_validation}), | ||||
|                     b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)), | ||||
|                 }, | ||||
|                 ctx.dst->ty.void_(), nullptr, | ||||
|                 b.ty.void_(), nullptr, | ||||
|                 utils::Vector{ | ||||
|                     ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(), | ||||
|                                                                     ctx.dst->AllocateNodeID()), | ||||
|                     b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()), | ||||
|                 }, | ||||
|                 utils::Empty)); | ||||
| 
 | ||||
| @ -123,7 +127,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|     std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage; | ||||
| 
 | ||||
|     // Find all the arrayLength() calls...
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* call_expr = node->As<ast::CallExpression>()) { | ||||
|             auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|             if (auto* builtin = call->Target()->As<sem::Builtin>()) { | ||||
| @ -149,7 +153,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|                     auto* arg = call_expr->args[0]; | ||||
|                     auto* address_of = arg->As<ast::UnaryOpExpression>(); | ||||
|                     if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { | ||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "arrayLength() expected address-of, got " << arg->TypeInfo().name; | ||||
|                     } | ||||
|                     auto* storage_buffer_expr = address_of->expr; | ||||
| @ -158,7 +162,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|                     } | ||||
|                     auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); | ||||
|                     if (!storage_buffer_sem) { | ||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "expected form of arrayLength argument to be &array_var or " | ||||
|                                "&struct_var.array_member"; | ||||
|                         break; | ||||
| @ -179,25 +183,24 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
| 
 | ||||
|                             // Construct the variable that'll hold the result of
 | ||||
|                             // RWByteAddressBuffer.GetDimensions()
 | ||||
|                             auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var( | ||||
|                                 ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u))); | ||||
|                             auto* buffer_size_result = | ||||
|                                 b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u))); | ||||
| 
 | ||||
|                             // Call storage_buffer.GetDimensions(&buffer_size_result)
 | ||||
|                             auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( | ||||
|                             auto* call_get_dims = b.CallStmt(b.Call( | ||||
|                                 // BufferSizeIntrinsic(X, ARGS...) is
 | ||||
|                                 // translated to:
 | ||||
|                                 //  X.GetDimensions(ARGS..) by the writer
 | ||||
|                                 buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)), | ||||
|                                 ctx.dst->AddressOf( | ||||
|                                     ctx.dst->Expr(buffer_size_result->variable->symbol)))); | ||||
|                                 buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)), | ||||
|                                 b.AddressOf(b.Expr(buffer_size_result->variable->symbol)))); | ||||
| 
 | ||||
|                             // Calculate actual array length
 | ||||
|                             //                total_storage_buffer_size - array_offset
 | ||||
|                             // array_length = ----------------------------------------
 | ||||
|                             //                             array_stride
 | ||||
|                             auto name = ctx.dst->Sym(); | ||||
|                             auto name = b.Sym(); | ||||
|                             const ast::Expression* total_size = | ||||
|                                 ctx.dst->Expr(buffer_size_result->variable); | ||||
|                                 b.Expr(buffer_size_result->variable); | ||||
| 
 | ||||
|                             const sem::Array* array_type = Switch( | ||||
|                                 storage_buffer_type->StoreType(), | ||||
| @ -205,23 +208,21 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|                                     // The variable is a struct, so subtract the byte offset of
 | ||||
|                                     // the array member.
 | ||||
|                                     auto* array_member_sem = str->Members().back(); | ||||
|                                     total_size = | ||||
|                                         ctx.dst->Sub(total_size, u32(array_member_sem->Offset())); | ||||
|                                     total_size = b.Sub(total_size, u32(array_member_sem->Offset())); | ||||
|                                     return array_member_sem->Type()->As<sem::Array>(); | ||||
|                                 }, | ||||
|                                 [&](const sem::Array* arr) { return arr; }); | ||||
| 
 | ||||
|                             if (!array_type) { | ||||
|                                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                                     << "expected form of arrayLength argument to be " | ||||
|                                        "&array_var or &struct_var.array_member"; | ||||
|                                 return name; | ||||
|                             } | ||||
| 
 | ||||
|                             uint32_t array_stride = array_type->Size(); | ||||
|                             auto* array_length_var = ctx.dst->Decl( | ||||
|                                 ctx.dst->Let(name, ctx.dst->ty.u32(), | ||||
|                                              ctx.dst->Div(total_size, u32(array_stride)))); | ||||
|                             auto* array_length_var = b.Decl( | ||||
|                                 b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride)))); | ||||
| 
 | ||||
|                             // Insert the array length calculations at the top of the block
 | ||||
|                             ctx.InsertBefore(block->statements, block->statements[0], | ||||
| @ -234,13 +235,14 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|                         }); | ||||
| 
 | ||||
|                     // Replace the call to arrayLength() with the array length variable
 | ||||
|                     ctx.Replace(call_expr, ctx.dst->Expr(array_length)); | ||||
|                     ctx.Replace(call_expr, b.Expr(array_length)); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -59,19 +59,10 @@ class CalculateArrayLength final : public Castable<CalculateArrayLength, Transfo | ||||
|     /// Destructor
 | ||||
|     ~CalculateArrayLength() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -123,7 +123,7 @@ bool HasSampleMask(utils::VectorRef<const ast::Attribute*> attrs) { | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// State holds the current transform state for a single entry point.
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct CanonicalizeEntryPointIO::State { | ||||
|     /// OutputValue represents a shader result that the wrapper function produces.
 | ||||
|     struct OutputValue { | ||||
| @ -770,17 +770,22 @@ struct CanonicalizeEntryPointIO::State { | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src, | ||||
|                                                        const DataMap& inputs, | ||||
|                                                        DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     auto* cfg = inputs.Get<Config>(); | ||||
|     if (cfg == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     // Remove entry point IO attributes from struct declarations.
 | ||||
|     // New structures will be created for each entry point, as necessary.
 | ||||
|     for (auto* ty : ctx.src->AST().TypeDecls()) { | ||||
|     for (auto* ty : src->AST().TypeDecls()) { | ||||
|         if (auto* struct_ty = ty->As<ast::Struct>()) { | ||||
|             for (auto* member : struct_ty->members) { | ||||
|                 for (auto* attr : member->attributes) { | ||||
| @ -792,7 +797,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     for (auto* func_ast : ctx.src->AST().Functions()) { | ||||
|     for (auto* func_ast : src->AST().Functions()) { | ||||
|         if (!func_ast->IsEntryPoint()) { | ||||
|             continue; | ||||
|         } | ||||
| @ -802,6 +807,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| CanonicalizeEntryPointIO::Config::Config(ShaderStyle style, | ||||
|  | ||||
| @ -127,15 +127,12 @@ class CanonicalizeEntryPointIO final : public Castable<CanonicalizeEntryPointIO, | ||||
|     CanonicalizeEntryPointIO(); | ||||
|     ~CanonicalizeEntryPointIO() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -14,7 +14,7 @@ | ||||
| 
 | ||||
| #include "src/tint/transform/clamp_frag_depth.h" | ||||
| 
 | ||||
|  #include <utility> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/ast/attribute.h" | ||||
| #include "src/tint/ast/builtin_attribute.h" | ||||
| @ -64,12 +64,7 @@ bool ReturnsFragDepthInStruct(const sem::Info& sem, const ast::Function* fn) { | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| }  // anonymous namespace
 | ||||
| 
 | ||||
| ClampFragDepth::ClampFragDepth() = default; | ||||
| ClampFragDepth::~ClampFragDepth() = default; | ||||
| 
 | ||||
| bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     auto& sem = program->Sem(); | ||||
| 
 | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
| @ -82,22 +77,33 @@ bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| }  // anonymous namespace
 | ||||
| 
 | ||||
| ClampFragDepth::ClampFragDepth() = default; | ||||
| ClampFragDepth::~ClampFragDepth() = default; | ||||
| 
 | ||||
| Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     // Abort on any use of push constants in the module.
 | ||||
|     for (auto* global : ctx.src->AST().GlobalVariables()) { | ||||
|     for (auto* global : src->AST().GlobalVariables()) { | ||||
|         if (auto* var = global->As<ast::Var>()) { | ||||
|             if (var->declared_address_space == ast::AddressSpace::kPushConstant) { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "ClampFragDepth doesn't know how to handle module that already use push " | ||||
|                        "constants."; | ||||
|                 return; | ||||
|                 return Program(std::move(b)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     auto& b = *ctx.dst; | ||||
|     auto& sem = ctx.src->Sem(); | ||||
|     auto& sym = ctx.src->Symbols(); | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     auto& sem = src->Sem(); | ||||
|     auto& sym = src->Symbols(); | ||||
| 
 | ||||
|     // At least one entry-point needs clamping. Add the following to the module:
 | ||||
|     //
 | ||||
| @ -197,6 +203,7 @@ void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -61,19 +61,10 @@ class ClampFragDepth final : public Castable<ClampFragDepth, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~ClampFragDepth() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -47,10 +47,14 @@ CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map, | ||||
| CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default; | ||||
| CombineSamplers::BindingInfo::~BindingInfo() = default; | ||||
| 
 | ||||
| /// The PIMPL state for the CombineSamplers transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct CombineSamplers::State { | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// The binding info
 | ||||
|     const BindingInfo* binding_info; | ||||
| @ -88,9 +92,9 @@ struct CombineSamplers::State { | ||||
|     } | ||||
| 
 | ||||
|     /// Constructor
 | ||||
|     /// @param context the clone context
 | ||||
|     /// @param program the source program
 | ||||
|     /// @param info the binding map information
 | ||||
|     State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {} | ||||
|     State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {} | ||||
| 
 | ||||
|     /// Creates a combined sampler global variables.
 | ||||
|     /// (Note this is actually a Texture node at the AST level, but it will be
 | ||||
| @ -145,8 +149,9 @@ struct CombineSamplers::State { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Performs the transformation
 | ||||
|     void Run() { | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         auto& sem = ctx.src->Sem(); | ||||
| 
 | ||||
|         // Remove all texture and sampler global variables. These will be replaced
 | ||||
| @ -169,14 +174,14 @@ struct CombineSamplers::State { | ||||
| 
 | ||||
|         // Rewrite all function signatures to use combined samplers, and remove
 | ||||
|         // separate textures & samplers. Create new combined globals where found.
 | ||||
|         ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* { | ||||
|             if (auto* func = sem.Get(src)) { | ||||
|                 auto pairs = func->TextureSamplerPairs(); | ||||
|         ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* { | ||||
|             if (auto* fn = sem.Get(ast_fn)) { | ||||
|                 auto pairs = fn->TextureSamplerPairs(); | ||||
|                 if (pairs.IsEmpty()) { | ||||
|                     return nullptr; | ||||
|                 } | ||||
|                 utils::Vector<const ast::Parameter*, 8> params; | ||||
|                 for (auto pair : func->TextureSamplerPairs()) { | ||||
|                 for (auto pair : fn->TextureSamplerPairs()) { | ||||
|                     const sem::Variable* texture_var = pair.first; | ||||
|                     const sem::Variable* sampler_var = pair.second; | ||||
|                     std::string name = | ||||
| @ -197,23 +202,23 @@ struct CombineSamplers::State { | ||||
|                         auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var); | ||||
|                         auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); | ||||
|                         params.Push(var); | ||||
|                         function_combined_texture_samplers_[func][pair] = var; | ||||
|                         function_combined_texture_samplers_[fn][pair] = var; | ||||
|                     } | ||||
|                 } | ||||
|                 // Filter out separate textures and samplers from the original
 | ||||
|                 // function signature.
 | ||||
|                 for (auto* var : src->params) { | ||||
|                     if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) { | ||||
|                         params.Push(ctx.Clone(var)); | ||||
|                 for (auto* param : fn->Parameters()) { | ||||
|                     if (!param->Type()->IsAnyOf<sem::Texture, sem::Sampler>()) { | ||||
|                         params.Push(ctx.Clone(param->Declaration())); | ||||
|                     } | ||||
|                 } | ||||
|                 // Create a new function signature that differs only in the parameter
 | ||||
|                 // list.
 | ||||
|                 auto symbol = ctx.Clone(src->symbol); | ||||
|                 auto* return_type = ctx.Clone(src->return_type); | ||||
|                 auto* body = ctx.Clone(src->body); | ||||
|                 auto attributes = ctx.Clone(src->attributes); | ||||
|                 auto return_type_attributes = ctx.Clone(src->return_type_attributes); | ||||
|                 auto symbol = ctx.Clone(ast_fn->symbol); | ||||
|                 auto* return_type = ctx.Clone(ast_fn->return_type); | ||||
|                 auto* body = ctx.Clone(ast_fn->body); | ||||
|                 auto attributes = ctx.Clone(ast_fn->attributes); | ||||
|                 auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes); | ||||
|                 return ctx.dst->create<ast::Function>(symbol, params, return_type, body, | ||||
|                                                       std::move(attributes), | ||||
|                                                       std::move(return_type_attributes)); | ||||
| @ -327,6 +332,7 @@ struct CombineSamplers::State { | ||||
|         }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| @ -334,15 +340,18 @@ CombineSamplers::CombineSamplers() = default; | ||||
| 
 | ||||
| CombineSamplers::~CombineSamplers() = default; | ||||
| 
 | ||||
| void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult CombineSamplers::Apply(const Program* src, | ||||
|                                               const DataMap& inputs, | ||||
|                                               DataMap&) const { | ||||
|     auto* binding_info = inputs.Get<BindingInfo>(); | ||||
|     if (!binding_info) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|         ProgramBuilder b; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     State(ctx, binding_info).Run(); | ||||
|     return State(src, binding_info).Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -88,17 +88,13 @@ class CombineSamplers final : public Castable<CombineSamplers, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~CombineSamplers() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// The PIMPL state for this transform
 | ||||
|     struct State; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -47,6 +47,18 @@ namespace tint::transform { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||
|         if (auto* var = program->Sem().Get<sem::Variable>(decl)) { | ||||
|             if (var->AddressSpace() == ast::AddressSpace::kStorage || | ||||
|                 var->AddressSpace() == ast::AddressSpace::kUniform) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| /// Offset is a simple ast::Expression builder interface, used to build byte
 | ||||
| /// offsets for storage and uniform buffer accesses.
 | ||||
| struct Offset : Castable<Offset> { | ||||
| @ -291,7 +303,7 @@ struct Store { | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// State holds the current transform state
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct DecomposeMemoryAccess::State { | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
| @ -477,7 +489,7 @@ struct DecomposeMemoryAccess::State { | ||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||
|                         //   this method only handles storage and uniform.
 | ||||
|                         // * Runtime-sized arrays are not loadable.
 | ||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "unexpected non-constant array count"; | ||||
|                         arr_cnt = 1; | ||||
|                     } | ||||
| @ -578,7 +590,7 @@ struct DecomposeMemoryAccess::State { | ||||
|                                 // * Override-expression counts can only be applied to workgroup
 | ||||
|                                 //   arrays, and this method only handles storage and uniform.
 | ||||
|                                 // * Runtime-sized arrays are not storable.
 | ||||
|                                 TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                                     << "unexpected non-constant array count"; | ||||
|                                 arr_cnt = 1; | ||||
|                             } | ||||
| @ -808,21 +820,16 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const { | ||||
| DecomposeMemoryAccess::DecomposeMemoryAccess() = default; | ||||
| DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; | ||||
| 
 | ||||
| bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||
|         if (auto* var = program->Sem().Get<sem::Variable>(decl)) { | ||||
|             if (var->AddressSpace() == ast::AddressSpace::kStorage || | ||||
|                 var->AddressSpace() == ast::AddressSpace::kUniform) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
| Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, | ||||
|                                                     const DataMap&, | ||||
|                                                     DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     auto& sem = ctx.src->Sem(); | ||||
| 
 | ||||
|     auto& sem = src->Sem(); | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     State state(ctx); | ||||
| 
 | ||||
|     // Scan the AST nodes for storage and uniform buffer accesses. Complex
 | ||||
| @ -833,7 +840,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con | ||||
|     // Inner-most expression nodes are guaranteed to be visited first because AST
 | ||||
|     // nodes are fully immutable and require their children to be constructed
 | ||||
|     // first so their pointer can be passed to the parent's initializer.
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* ident = node->As<ast::IdentifierExpression>()) { | ||||
|             // X
 | ||||
|             if (auto* var = sem.Get<sem::VariableUser>(ident)) { | ||||
| @ -1001,6 +1008,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -108,20 +108,12 @@ class DecomposeMemoryAccess final : public Castable<DecomposeMemoryAccess, Trans | ||||
|     /// Destructor
 | ||||
|     ~DecomposeMemoryAccess() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -34,13 +34,7 @@ namespace { | ||||
| 
 | ||||
| using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| DecomposeStridedArray::DecomposeStridedArray() = default; | ||||
| 
 | ||||
| DecomposeStridedArray::~DecomposeStridedArray() = default; | ||||
| 
 | ||||
| bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* ast = node->As<ast::Array>()) { | ||||
|             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { | ||||
| @ -51,8 +45,22 @@ bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) co | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     const auto& sem = ctx.src->Sem(); | ||||
| }  // namespace
 | ||||
| 
 | ||||
| DecomposeStridedArray::DecomposeStridedArray() = default; | ||||
| 
 | ||||
| DecomposeStridedArray::~DecomposeStridedArray() = default; | ||||
| 
 | ||||
| Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src, | ||||
|                                                     const DataMap&, | ||||
|                                                     DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     const auto& sem = src->Sem(); | ||||
| 
 | ||||
|     static constexpr const char* kMemberName = "el"; | ||||
| 
 | ||||
| @ -69,23 +77,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | ||||
|         if (auto* arr = sem.Get(ast)) { | ||||
|             if (!arr->IsStrideImplicit()) { | ||||
|                 auto el_ty = utils::GetOrCreate(decomposed, arr, [&] { | ||||
|                     auto name = ctx.dst->Symbols().New("strided_arr"); | ||||
|                     auto name = b.Symbols().New("strided_arr"); | ||||
|                     auto* member_ty = ctx.Clone(ast->type); | ||||
|                     auto* member = ctx.dst->Member(kMemberName, member_ty, | ||||
|                                                    utils::Vector{ | ||||
|                                                        ctx.dst->MemberSize(AInt(arr->Stride())), | ||||
|                                                    }); | ||||
|                     ctx.dst->Structure(name, utils::Vector{member}); | ||||
|                     auto* member = b.Member(kMemberName, member_ty, | ||||
|                                             utils::Vector{ | ||||
|                                                 b.MemberSize(AInt(arr->Stride())), | ||||
|                                             }); | ||||
|                     b.Structure(name, utils::Vector{member}); | ||||
|                     return name; | ||||
|                 }); | ||||
|                 auto* count = ctx.Clone(ast->count); | ||||
|                 return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count); | ||||
|                 return b.ty.array(b.ty.type_name(el_ty), count); | ||||
|             } | ||||
|             if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) { | ||||
|                 // Strip the @stride attribute
 | ||||
|                 auto* ty = ctx.Clone(ast->type); | ||||
|                 auto* count = ctx.Clone(ast->count); | ||||
|                 return ctx.dst->ty.array(ty, count); | ||||
|                 return b.ty.array(ty, count); | ||||
|             } | ||||
|         } | ||||
|         return nullptr; | ||||
| @ -96,11 +104,11 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | ||||
|     // to insert an additional member accessor for the single structure field.
 | ||||
|     // Example: `arr[i]` -> `arr[i].el`
 | ||||
|     ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* { | ||||
|         if (auto* ty = ctx.src->TypeOf(idx->object)) { | ||||
|         if (auto* ty = src->TypeOf(idx->object)) { | ||||
|             if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) { | ||||
|                 if (!arr->IsStrideImplicit()) { | ||||
|                     auto* expr = ctx.CloneWithoutTransform(idx); | ||||
|                     return ctx.dst->MemberAccessor(expr, kMemberName); | ||||
|                     return b.MemberAccessor(expr, kMemberName); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @ -136,21 +144,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con | ||||
|                         if (auto it = decomposed.find(arr); it != decomposed.end()) { | ||||
|                             args.Reserve(expr->args.Length()); | ||||
|                             for (auto* arg : expr->args) { | ||||
|                                 args.Push(ctx.dst->Call(it->second, ctx.Clone(arg))); | ||||
|                                 args.Push(b.Call(it->second, ctx.Clone(arg))); | ||||
|                             } | ||||
|                         } else { | ||||
|                             args = ctx.Clone(expr->args); | ||||
|                         } | ||||
| 
 | ||||
|                         return target.type ? ctx.dst->Construct(target.type, std::move(args)) | ||||
|                                            : ctx.dst->Call(target.name, std::move(args)); | ||||
|                         return target.type ? b.Construct(target.type, std::move(args)) | ||||
|                                            : b.Call(target.name, std::move(args)); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return nullptr; | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -35,19 +35,10 @@ class DecomposeStridedArray final : public Castable<DecomposeStridedArray, Trans | ||||
|     /// Destructor
 | ||||
|     ~DecomposeStridedArray() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -53,24 +53,25 @@ struct MatrixInfo { | ||||
|     }; | ||||
| }; | ||||
| 
 | ||||
| /// Return type of the callback function of GatherCustomStrideMatrixMembers
 | ||||
| enum GatherResult { kContinue, kStop }; | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
 | ||||
| /// storage and uniform structs, which are of a matrix type, and have a custom
 | ||||
| /// matrix stride attribute. For each matrix member found, `callback` is called.
 | ||||
| /// `callback` is a function with the signature:
 | ||||
| ///      GatherResult(const sem::StructMember* member,
 | ||||
| ///                   sem::Matrix* matrix,
 | ||||
| ///                   uint32_t stride)
 | ||||
| /// If `callback` return GatherResult::kStop, then the scanning will immediately
 | ||||
| /// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
 | ||||
| /// scanning will continue.
 | ||||
| template <typename F> | ||||
| void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
| DecomposeStridedMatrix::DecomposeStridedMatrix() = default; | ||||
| 
 | ||||
| DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; | ||||
| 
 | ||||
| Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, | ||||
|                                                      const DataMap&, | ||||
|                                                      DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     // Scan the program for all storage and uniform structure matrix members with
 | ||||
|     // a custom stride attribute. Replace these matrices with an equivalent array,
 | ||||
|     // and populate the `decomposed` map with the members that have been replaced.
 | ||||
|     utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* str = node->As<ast::Struct>()) { | ||||
|             auto* str_ty = program->Sem().Get(str); | ||||
|             auto* str_ty = src->Sem().Get(str); | ||||
|             if (!str_ty->UsedAs(ast::AddressSpace::kUniform) && | ||||
|                 !str_ty->UsedAs(ast::AddressSpace::kStorage)) { | ||||
|                 continue; | ||||
| @ -89,46 +90,20 @@ void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { | ||||
|                 if (matrix->ColumnStride() == stride) { | ||||
|                     continue; | ||||
|                 } | ||||
|                 if (callback(member, matrix, stride) == GatherResult::kStop) { | ||||
|                     return; | ||||
|                 } | ||||
|                 // We've got ourselves a struct member of a matrix type with a custom
 | ||||
|                 // stride. Replace this with an array of column vectors.
 | ||||
|                 MatrixInfo info{stride, matrix}; | ||||
|                 auto* replacement = | ||||
|                     b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); | ||||
|                 ctx.Replace(member->Declaration(), replacement); | ||||
|                 decomposed.Add(member->Declaration(), info); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| DecomposeStridedMatrix::DecomposeStridedMatrix() = default; | ||||
| 
 | ||||
| DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; | ||||
| 
 | ||||
| bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     bool should_run = false; | ||||
|     GatherCustomStrideMatrixMembers(program, | ||||
|                                     [&](const sem::StructMember*, const sem::Matrix*, uint32_t) { | ||||
|                                         should_run = true; | ||||
|                                         return GatherResult::kStop; | ||||
|                                     }); | ||||
|     return should_run; | ||||
| } | ||||
| 
 | ||||
| void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     // Scan the program for all storage and uniform structure matrix members with
 | ||||
|     // a custom stride attribute. Replace these matrices with an equivalent array,
 | ||||
|     // and populate the `decomposed` map with the members that have been replaced.
 | ||||
|     std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed; | ||||
|     GatherCustomStrideMatrixMembers( | ||||
|         ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) { | ||||
|             // We've got ourselves a struct member of a matrix type with a custom
 | ||||
|             // stride. Replace this with an array of column vectors.
 | ||||
|             MatrixInfo info{stride, matrix}; | ||||
|             auto* replacement = | ||||
|                 ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); | ||||
|             ctx.Replace(member->Declaration(), replacement); | ||||
|             decomposed.emplace(member->Declaration(), info); | ||||
|             return GatherResult::kContinue; | ||||
|         }); | ||||
|     if (decomposed.IsEmpty()) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     // For all expressions where a single matrix column vector was indexed, we can
 | ||||
|     // preserve these without calling conversion functions.
 | ||||
| @ -136,12 +111,11 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | ||||
|     //   ssbo.mat[2] -> ssbo.mat[2]
 | ||||
|     ctx.ReplaceAll( | ||||
|         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { | ||||
|             if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) { | ||||
|                 auto it = decomposed.find(access->Member()->Declaration()); | ||||
|                 if (it != decomposed.end()) { | ||||
|             if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) { | ||||
|                 if (decomposed.Contains(access->Member()->Declaration())) { | ||||
|                     auto* obj = ctx.CloneWithoutTransform(expr->object); | ||||
|                     auto* idx = ctx.Clone(expr->index); | ||||
|                     return ctx.dst->IndexAccessor(obj, idx); | ||||
|                     return b.IndexAccessor(obj, idx); | ||||
|                 } | ||||
|             } | ||||
|             return nullptr; | ||||
| @ -154,39 +128,36 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | ||||
|     //   ssbo.mat = mat_to_arr(m)
 | ||||
|     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr; | ||||
|     ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { | ||||
|         if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { | ||||
|             auto it = decomposed.find(access->Member()->Declaration()); | ||||
|             if (it == decomposed.end()) { | ||||
|                 return nullptr; | ||||
|         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { | ||||
|             if (auto* info = decomposed.Find(access->Member()->Declaration())) { | ||||
|                 auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { | ||||
|                     auto name = | ||||
|                         b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" + | ||||
|                                         std::to_string(info->matrix->rows()) + "_stride_" + | ||||
|                                         std::to_string(info->stride) + "_to_arr"); | ||||
| 
 | ||||
|                     auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; | ||||
|                     auto array = [&] { return info->array(ctx.dst); }; | ||||
| 
 | ||||
|                     auto mat = b.Sym("m"); | ||||
|                     utils::Vector<const ast::Expression*, 4> columns; | ||||
|                     for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { | ||||
|                         columns.Push(b.IndexAccessor(mat, u32(i))); | ||||
|                     } | ||||
|                     b.Func(name, | ||||
|                            utils::Vector{ | ||||
|                                b.Param(mat, matrix()), | ||||
|                            }, | ||||
|                            array(), | ||||
|                            utils::Vector{ | ||||
|                                b.Return(b.Construct(array(), columns)), | ||||
|                            }); | ||||
|                     return name; | ||||
|                 }); | ||||
|                 auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); | ||||
|                 auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs)); | ||||
|                 return b.Assign(lhs, rhs); | ||||
|             } | ||||
|             MatrixInfo info = it->second; | ||||
|             auto fn = utils::GetOrCreate(mat_to_arr, info, [&] { | ||||
|                 auto name = | ||||
|                     ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" + | ||||
|                                            std::to_string(info.matrix->rows()) + "_stride_" + | ||||
|                                            std::to_string(info.stride) + "_to_arr"); | ||||
| 
 | ||||
|                 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; | ||||
|                 auto array = [&] { return info.array(ctx.dst); }; | ||||
| 
 | ||||
|                 auto mat = ctx.dst->Sym("m"); | ||||
|                 utils::Vector<const ast::Expression*, 4> columns; | ||||
|                 for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) { | ||||
|                     columns.Push(ctx.dst->IndexAccessor(mat, u32(i))); | ||||
|                 } | ||||
|                 ctx.dst->Func(name, | ||||
|                               utils::Vector{ | ||||
|                                   ctx.dst->Param(mat, matrix()), | ||||
|                               }, | ||||
|                               array(), | ||||
|                               utils::Vector{ | ||||
|                                   ctx.dst->Return(ctx.dst->Construct(array(), columns)), | ||||
|                               }); | ||||
|                 return name; | ||||
|             }); | ||||
|             auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); | ||||
|             auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs)); | ||||
|             return ctx.dst->Assign(lhs, rhs); | ||||
|         } | ||||
|         return nullptr; | ||||
|     }); | ||||
| @ -196,41 +167,40 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co | ||||
|     //   m = arr_to_mat(ssbo.mat)
 | ||||
|     std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat; | ||||
|     ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { | ||||
|         if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) { | ||||
|             auto it = decomposed.find(access->Member()->Declaration()); | ||||
|             if (it == decomposed.end()) { | ||||
|                 return nullptr; | ||||
|         if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) { | ||||
|             if (auto* info = decomposed.Find(access->Member()->Declaration())) { | ||||
|                 auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { | ||||
|                     auto name = | ||||
|                         b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) + | ||||
|                                         "x" + std::to_string(info->matrix->rows()) + "_stride_" + | ||||
|                                         std::to_string(info->stride)); | ||||
| 
 | ||||
|                     auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; | ||||
|                     auto array = [&] { return info->array(ctx.dst); }; | ||||
| 
 | ||||
|                     auto arr = b.Sym("arr"); | ||||
|                     utils::Vector<const ast::Expression*, 4> columns; | ||||
|                     for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { | ||||
|                         columns.Push(b.IndexAccessor(arr, u32(i))); | ||||
|                     } | ||||
|                     b.Func(name, | ||||
|                            utils::Vector{ | ||||
|                                b.Param(arr, array()), | ||||
|                            }, | ||||
|                            matrix(), | ||||
|                            utils::Vector{ | ||||
|                                b.Return(b.Construct(matrix(), columns)), | ||||
|                            }); | ||||
|                     return name; | ||||
|                 }); | ||||
|                 return b.Call(fn, ctx.CloneWithoutTransform(expr)); | ||||
|             } | ||||
|             MatrixInfo info = it->second; | ||||
|             auto fn = utils::GetOrCreate(arr_to_mat, info, [&] { | ||||
|                 auto name = ctx.dst->Symbols().New( | ||||
|                     "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" + | ||||
|                     std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride)); | ||||
| 
 | ||||
|                 auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; | ||||
|                 auto array = [&] { return info.array(ctx.dst); }; | ||||
| 
 | ||||
|                 auto arr = ctx.dst->Sym("arr"); | ||||
|                 utils::Vector<const ast::Expression*, 4> columns; | ||||
|                 for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) { | ||||
|                     columns.Push(ctx.dst->IndexAccessor(arr, u32(i))); | ||||
|                 } | ||||
|                 ctx.dst->Func(name, | ||||
|                               utils::Vector{ | ||||
|                                   ctx.dst->Param(arr, array()), | ||||
|                               }, | ||||
|                               matrix(), | ||||
|                               utils::Vector{ | ||||
|                                   ctx.dst->Return(ctx.dst->Construct(matrix(), columns)), | ||||
|                               }); | ||||
|                 return name; | ||||
|             }); | ||||
|             return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr)); | ||||
|         } | ||||
|         return nullptr; | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -35,19 +35,10 @@ class DecomposeStridedMatrix final : public Castable<DecomposeStridedMatrix, Tra | ||||
|     /// Destructor
 | ||||
|     ~DecomposeStridedMatrix() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -27,14 +27,20 @@ DisableUniformityAnalysis::DisableUniformityAnalysis() = default; | ||||
| 
 | ||||
| DisableUniformityAnalysis::~DisableUniformityAnalysis() = default; | ||||
| 
 | ||||
| bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return !program->Sem().Module()->Extensions().Contains( | ||||
|         ast::Extension::kChromiumDisableUniformityAnalysis); | ||||
| } | ||||
| Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src, | ||||
|                                                         const DataMap&, | ||||
|                                                         DataMap&) const { | ||||
|     if (src->Sem().Module()->Extensions().Contains( | ||||
|             ast::Extension::kChromiumDisableUniformityAnalysis)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis); | ||||
| 
 | ||||
| void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis); | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -27,19 +27,10 @@ class DisableUniformityAnalysis final : public Castable<DisableUniformityAnalysi | ||||
|     /// Destructor
 | ||||
|     ~DisableUniformityAnalysis() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,11 +31,9 @@ using namespace tint::number_suffixes;  // NOLINT | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| ExpandCompoundAssignment::ExpandCompoundAssignment() = default; | ||||
| namespace { | ||||
| 
 | ||||
| ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; | ||||
| 
 | ||||
| bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) { | ||||
|             return true; | ||||
| @ -44,21 +42,10 @@ bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| namespace { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// Internal class used to collect statement expansions during the transform.
 | ||||
| class State { | ||||
|   private: | ||||
|     /// The clone context.
 | ||||
|     CloneContext& ctx; | ||||
| 
 | ||||
|     /// The program builder.
 | ||||
|     ProgramBuilder& b; | ||||
| 
 | ||||
|     /// The HoistToDeclBefore helper instance.
 | ||||
|     HoistToDeclBefore hoist_to_decl_before; | ||||
| 
 | ||||
|   public: | ||||
| /// PIMPL state for the transform
 | ||||
| struct ExpandCompoundAssignment::State { | ||||
|     /// Constructor
 | ||||
|     /// @param context the clone context
 | ||||
|     explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {} | ||||
| @ -158,15 +145,32 @@ class State { | ||||
|         ctx.Replace(stmt, b.Assign(new_lhs(), value)); | ||||
|     } | ||||
| 
 | ||||
|     /// Finalize the transformation and clone the module.
 | ||||
|     void Finalize() { ctx.Clone(); } | ||||
|   private: | ||||
|     /// The clone context.
 | ||||
|     CloneContext& ctx; | ||||
| 
 | ||||
|     /// The program builder.
 | ||||
|     ProgramBuilder& b; | ||||
| 
 | ||||
|     /// The HoistToDeclBefore helper instance.
 | ||||
|     HoistToDeclBefore hoist_to_decl_before; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| ExpandCompoundAssignment::ExpandCompoundAssignment() = default; | ||||
| 
 | ||||
| void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; | ||||
| 
 | ||||
| Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src, | ||||
|                                                        const DataMap&, | ||||
|                                                        DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     State state(ctx); | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) { | ||||
|             state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op); | ||||
|         } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) { | ||||
| @ -175,7 +179,9 @@ void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) | ||||
|             state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op); | ||||
|         } | ||||
|     } | ||||
|     state.Finalize(); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -45,19 +45,13 @@ class ExpandCompoundAssignment final : public Castable<ExpandCompoundAssignment, | ||||
|     /// Destructor
 | ||||
|     ~ExpandCompoundAssignment() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -35,6 +35,15 @@ namespace { | ||||
| constexpr char kFirstVertexName[] = "first_vertex_index"; | ||||
| constexpr char kFirstInstanceName[] = "first_instance_index"; | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| FirstIndexOffset::BindingPoint::BindingPoint() = default; | ||||
| @ -49,16 +58,16 @@ FirstIndexOffset::Data::~Data() = default; | ||||
| FirstIndexOffset::FirstIndexOffset() = default; | ||||
| FirstIndexOffset::~FirstIndexOffset() = default; | ||||
| 
 | ||||
| bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||
|             return true; | ||||
|         } | ||||
| Transform::ApplyResult FirstIndexOffset::Apply(const Program* src, | ||||
|                                                const DataMap& inputs, | ||||
|                                                DataMap& outputs) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     // Get the uniform buffer binding point
 | ||||
|     uint32_t ub_binding = binding_; | ||||
|     uint32_t ub_group = group_; | ||||
| @ -115,17 +124,17 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou | ||||
|     if (has_vertex_or_instance_index) { | ||||
|         // Add uniform buffer members and calculate byte offsets
 | ||||
|         utils::Vector<const ast::StructMember*, 8> members; | ||||
|         members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32())); | ||||
|         members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32())); | ||||
|         auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members)); | ||||
|         members.Push(b.Member(kFirstVertexName, b.ty.u32())); | ||||
|         members.Push(b.Member(kFirstInstanceName, b.ty.u32())); | ||||
|         auto* struct_ = b.Structure(b.Sym(), std::move(members)); | ||||
| 
 | ||||
|         // Create a global to hold the uniform buffer
 | ||||
|         Symbol buffer_name = ctx.dst->Sym(); | ||||
|         ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform, | ||||
|                            utils::Vector{ | ||||
|                                ctx.dst->Binding(AInt(ub_binding)), | ||||
|                                ctx.dst->Group(AInt(ub_group)), | ||||
|                            }); | ||||
|         Symbol buffer_name = b.Sym(); | ||||
|         b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform, | ||||
|                     utils::Vector{ | ||||
|                         b.Binding(AInt(ub_binding)), | ||||
|                         b.Group(AInt(ub_group)), | ||||
|                     }); | ||||
| 
 | ||||
|         // Fix up all references to the builtins with the offsets
 | ||||
|         ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { | ||||
| @ -150,9 +159,10 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
| 
 | ||||
|     outputs.Add<Data>(has_vertex_or_instance_index); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -103,19 +103,10 @@ class FirstIndexOffset final : public Castable<FirstIndexOffset, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~FirstIndexOffset() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     uint32_t binding_ = 0; | ||||
|  | ||||
| @ -14,17 +14,17 @@ | ||||
| 
 | ||||
| #include "src/tint/transform/for_loop_to_loop.h" | ||||
| 
 | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/ast/break_statement.h" | ||||
| #include "src/tint/program_builder.h" | ||||
| 
 | ||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| ForLoopToLoop::ForLoopToLoop() = default; | ||||
| namespace { | ||||
| 
 | ||||
| ForLoopToLoop::~ForLoopToLoop() = default; | ||||
| 
 | ||||
| bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (node->Is<ast::ForLoopStatement>()) { | ||||
|             return true; | ||||
| @ -33,19 +33,31 @@ bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| ForLoopToLoop::ForLoopToLoop() = default; | ||||
| 
 | ||||
| ForLoopToLoop::~ForLoopToLoop() = default; | ||||
| 
 | ||||
| Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { | ||||
|         utils::Vector<const ast::Statement*, 8> stmts; | ||||
|         if (auto* cond = for_loop->condition) { | ||||
|             // !condition
 | ||||
|             auto* not_cond = | ||||
|                 ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond)); | ||||
|             auto* not_cond = b.Not(ctx.Clone(cond)); | ||||
| 
 | ||||
|             // { break; }
 | ||||
|             auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>()); | ||||
|             auto* break_body = b.Block(b.Break()); | ||||
| 
 | ||||
|             // if (!condition) { break; }
 | ||||
|             stmts.Push(ctx.dst->If(not_cond, break_body)); | ||||
|             stmts.Push(b.If(not_cond, break_body)); | ||||
|         } | ||||
|         for (auto* stmt : for_loop->body->statements) { | ||||
|             stmts.Push(ctx.Clone(stmt)); | ||||
| @ -53,20 +65,21 @@ void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| 
 | ||||
|         const ast::BlockStatement* continuing = nullptr; | ||||
|         if (auto* cont = for_loop->continuing) { | ||||
|             continuing = ctx.dst->Block(ctx.Clone(cont)); | ||||
|             continuing = b.Block(ctx.Clone(cont)); | ||||
|         } | ||||
| 
 | ||||
|         auto* body = ctx.dst->Block(stmts); | ||||
|         auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing); | ||||
|         auto* body = b.Block(stmts); | ||||
|         auto* loop = b.Loop(body, continuing); | ||||
| 
 | ||||
|         if (auto* init = for_loop->initializer) { | ||||
|             return ctx.dst->Block(ctx.Clone(init), loop); | ||||
|             return b.Block(ctx.Clone(init), loop); | ||||
|         } | ||||
| 
 | ||||
|         return loop; | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -29,19 +29,10 @@ class ForLoopToLoop final : public Castable<ForLoopToLoop, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~ForLoopToLoop() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -32,70 +32,15 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// Private implementation of LocalizeStructArrayAssignment transform
 | ||||
| class LocalizeStructArrayAssignment::State { | ||||
|   private: | ||||
|     CloneContext& ctx; | ||||
|     ProgramBuilder& b; | ||||
| 
 | ||||
|     /// Returns true if `expr` contains an index accessor expression to a
 | ||||
|     /// structure member of array type.
 | ||||
|     bool ContainsStructArrayIndex(const ast::Expression* expr) { | ||||
|         bool result = false; | ||||
|         ast::TraverseExpressions( | ||||
|             expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { | ||||
|                 // Indexing using a runtime value?
 | ||||
|                 auto* idx_sem = ctx.src->Sem().Get(ia->index); | ||||
|                 if (!idx_sem->ConstantValue()) { | ||||
|                     // Indexing a member access expr?
 | ||||
|                     if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { | ||||
|                         // That accesses an array?
 | ||||
|                         if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) { | ||||
|                             result = true; | ||||
|                             return ast::TraverseAction::Stop; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 return ast::TraverseAction::Descend; | ||||
|             }); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     // Returns the type and address space of the originating variable of the lhs
 | ||||
|     // of the assignment statement.
 | ||||
|     // See https://www.w3.org/TR/WGSL/#originating-variable-section
 | ||||
|     std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace( | ||||
|         const ast::AssignmentStatement* assign_stmt) { | ||||
|         auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable(); | ||||
|         if (!source_var) { | ||||
|             TINT_ICE(Transform, b.Diagnostics()) | ||||
|                 << "Unable to determine originating variable for lhs of assignment " | ||||
|                    "statement"; | ||||
|             return {}; | ||||
|         } | ||||
| 
 | ||||
|         auto* type = source_var->Type(); | ||||
|         if (auto* ref = type->As<sem::Reference>()) { | ||||
|             return {ref->StoreType(), ref->AddressSpace()}; | ||||
|         } else if (auto* ptr = type->As<sem::Pointer>()) { | ||||
|             return {ptr->StoreType(), ptr->AddressSpace()}; | ||||
|         } | ||||
| 
 | ||||
|         TINT_ICE(Transform, b.Diagnostics()) | ||||
|             << "Expecting to find variable of type pointer or reference on lhs " | ||||
|                "of assignment statement"; | ||||
|         return {}; | ||||
|     } | ||||
| 
 | ||||
|   public: | ||||
| /// PIMPL state for the transform
 | ||||
| struct LocalizeStructArrayAssignment::State { | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         struct Shared { | ||||
|             bool process_nested_nodes = false; | ||||
|             utils::Vector<const ast::Statement*, 4> insert_before_stmts; | ||||
| @ -189,6 +134,65 @@ class LocalizeStructArrayAssignment::State { | ||||
|             }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// Returns true if `expr` contains an index accessor expression to a
 | ||||
|     /// structure member of array type.
 | ||||
|     bool ContainsStructArrayIndex(const ast::Expression* expr) { | ||||
|         bool result = false; | ||||
|         ast::TraverseExpressions( | ||||
|             expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { | ||||
|                 // Indexing using a runtime value?
 | ||||
|                 auto* idx_sem = src->Sem().Get(ia->index); | ||||
|                 if (!idx_sem->ConstantValue()) { | ||||
|                     // Indexing a member access expr?
 | ||||
|                     if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { | ||||
|                         // That accesses an array?
 | ||||
|                         if (src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) { | ||||
|                             result = true; | ||||
|                             return ast::TraverseAction::Stop; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 return ast::TraverseAction::Descend; | ||||
|             }); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     // Returns the type and address space of the originating variable of the lhs
 | ||||
|     // of the assignment statement.
 | ||||
|     // See https://www.w3.org/TR/WGSL/#originating-variable-section
 | ||||
|     std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace( | ||||
|         const ast::AssignmentStatement* assign_stmt) { | ||||
|         auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable(); | ||||
|         if (!source_var) { | ||||
|             TINT_ICE(Transform, b.Diagnostics()) | ||||
|                 << "Unable to determine originating variable for lhs of assignment " | ||||
|                    "statement"; | ||||
|             return {}; | ||||
|         } | ||||
| 
 | ||||
|         auto* type = source_var->Type(); | ||||
|         if (auto* ref = type->As<sem::Reference>()) { | ||||
|             return {ref->StoreType(), ref->AddressSpace()}; | ||||
|         } else if (auto* ptr = type->As<sem::Pointer>()) { | ||||
|             return {ptr->StoreType(), ptr->AddressSpace()}; | ||||
|         } | ||||
| 
 | ||||
|         TINT_ICE(Transform, b.Diagnostics()) | ||||
|             << "Expecting to find variable of type pointer or reference on lhs " | ||||
|                "of assignment statement"; | ||||
|         return {}; | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| @ -196,9 +200,10 @@ LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default; | ||||
| 
 | ||||
| LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default; | ||||
| 
 | ||||
| void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State state(ctx); | ||||
|     state.Run(); | ||||
| Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src, | ||||
|                                                             const DataMap&, | ||||
|                                                             DataMap&) const { | ||||
|     return State{src}.Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -36,17 +36,13 @@ class LocalizeStructArrayAssignment final | ||||
|     /// Destructor
 | ||||
|     ~LocalizeStructArrayAssignment() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     class State; | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,9 +31,9 @@ namespace tint::transform { | ||||
| Manager::Manager() = default; | ||||
| Manager::~Manager() = default; | ||||
| 
 | ||||
| Output Manager::Run(const Program* program, const DataMap& data) const { | ||||
|     const Program* in = program; | ||||
| 
 | ||||
| Transform::ApplyResult Manager::Apply(const Program* program, | ||||
|                                       const DataMap& inputs, | ||||
|                                       DataMap& outputs) const { | ||||
| #if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM | ||||
|     auto print_program = [&](const char* msg, const Transform* transform) { | ||||
|         auto wgsl = Program::printer(in); | ||||
| @ -46,34 +46,30 @@ Output Manager::Run(const Program* program, const DataMap& data) const { | ||||
|     }; | ||||
| #endif | ||||
| 
 | ||||
|     Output out; | ||||
|     std::optional<Program> output; | ||||
| 
 | ||||
|     for (const auto& transform : transforms_) { | ||||
|         if (!transform->ShouldRun(in, data)) { | ||||
|             TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name | ||||
|                                             << std::endl); | ||||
|             continue; | ||||
|         } | ||||
|         TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); | ||||
| 
 | ||||
|         auto res = transform->Run(in, data); | ||||
|         out.program = std::move(res.program); | ||||
|         out.data.Add(std::move(res.data)); | ||||
|         in = &out.program; | ||||
|         if (!in->IsValid()) { | ||||
|             TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); | ||||
|             return out; | ||||
|         } | ||||
|         if (auto result = transform->Apply(program, inputs, outputs)) { | ||||
|             output.emplace(std::move(result.value())); | ||||
|             program = &output.value(); | ||||
| 
 | ||||
|         if (transform == transforms_.back()) { | ||||
|             TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); | ||||
|             if (!program->IsValid()) { | ||||
|                 TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); | ||||
|                 break; | ||||
|             } | ||||
| 
 | ||||
|             if (transform == transforms_.back()) { | ||||
|                 TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); | ||||
|             } | ||||
|         } else { | ||||
|             TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name | ||||
|                                             << std::endl); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (program == in) { | ||||
|         out.program = program->Clone(); | ||||
|     } | ||||
| 
 | ||||
|     return out; | ||||
|     return output; | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -47,11 +47,10 @@ class Manager final : public Castable<Manager, Transform> { | ||||
|         transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...)); | ||||
|     } | ||||
| 
 | ||||
|     /// Runs the transforms on `program`, returning the transformation result.
 | ||||
|     /// @param program the source program to transform
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns the transformed program and diagnostics
 | ||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     std::vector<std::unique_ptr<Transform>> transforms_; | ||||
|  | ||||
| @ -65,15 +65,6 @@ MergeReturn::MergeReturn() = default; | ||||
| 
 | ||||
| MergeReturn::~MergeReturn() = default; | ||||
| 
 | ||||
| bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* func : program->AST().Functions()) { | ||||
|         if (NeedsTransform(program, func)) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| /// Internal class used to during the transform.
 | ||||
| @ -223,7 +214,12 @@ class State { | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     bool made_changes = false; | ||||
| 
 | ||||
|     for (auto* func : ctx.src->AST().Functions()) { | ||||
|         if (!NeedsTransform(ctx.src, func)) { | ||||
|             continue; | ||||
| @ -231,9 +227,15 @@ void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| 
 | ||||
|         State state(ctx, func); | ||||
|         state.ProcessStatement(func->body); | ||||
|         made_changes = true; | ||||
|     } | ||||
| 
 | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -27,19 +27,10 @@ class MergeReturn final : public Castable<MergeReturn, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~MergeReturn() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -38,6 +38,15 @@ using WorkgroupParameterMemberList = utils::Vector<const ast::StructMember*, 8>; | ||||
| // The name of the struct member for arrays that are wrapped in structures.
 | ||||
| const char* kWrappedArrayMemberName = "arr"; | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||
|         if (decl->Is<ast::Variable>()) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| // Returns `true` if `type` is or contains a matrix type.
 | ||||
| bool ContainsMatrix(const sem::Type* type) { | ||||
|     type = type->UnwrapRef(); | ||||
| @ -56,7 +65,7 @@ bool ContainsMatrix(const sem::Type* type) { | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// State holds the current transform state.
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct ModuleScopeVarToEntryPointParam::State { | ||||
|     /// The clone context.
 | ||||
|     CloneContext& ctx; | ||||
| @ -501,19 +510,20 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default; | ||||
| 
 | ||||
| ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; | ||||
| 
 | ||||
| bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* decl : program->AST().GlobalDeclarations()) { | ||||
|         if (decl->Is<ast::Variable>()) { | ||||
|             return true; | ||||
|         } | ||||
| Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src, | ||||
|                                                               const DataMap&, | ||||
|                                                               DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     State state{ctx}; | ||||
|     state.Process(); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -69,20 +69,12 @@ class ModuleScopeVarToEntryPointParam final | ||||
|     /// Destructor
 | ||||
|     ~ModuleScopeVarToEntryPointParam() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -31,6 +31,17 @@ using namespace tint::number_suffixes;  // NOLINT | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* ty = node->As<ast::Type>()) { | ||||
|             if (program->Sem().Get<sem::ExternalTexture>(ty)) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| /// This struct stores symbols for new bindings created as a result of transforming a
 | ||||
| /// texture_external instance.
 | ||||
| struct NewBindingSymbols { | ||||
| @ -40,7 +51,7 @@ struct NewBindingSymbols { | ||||
| }; | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// State holds the current transform state
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct MultiplanarExternalTexture::State { | ||||
|     /// The clone context.
 | ||||
|     CloneContext& ctx; | ||||
| @ -537,30 +548,26 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default; | ||||
| MultiplanarExternalTexture::MultiplanarExternalTexture() = default; | ||||
| MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; | ||||
| 
 | ||||
| bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* ty = node->As<ast::Type>()) { | ||||
|             if (program->Sem().Get<sem::ExternalTexture>(ty)) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| // Within this transform, an instance of a texture_external binding is unpacked into two
 | ||||
| // texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
 | ||||
| // buffer binding representing a struct of parameters. Calls to texture builtins that contain a
 | ||||
| // texture_external parameter will be transformed into a newly generated version of the function,
 | ||||
| // which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
 | ||||
| void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src, | ||||
|                                                          const DataMap& inputs, | ||||
|                                                          DataMap&) const { | ||||
|     auto* new_binding_points = inputs.Get<NewBindingPoints>(); | ||||
| 
 | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     if (!new_binding_points) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, | ||||
|             "missing new binding point data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " + | ||||
|                                                                std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     State state(ctx, new_binding_points); | ||||
| @ -568,6 +575,7 @@ void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, D | ||||
|     state.Process(); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -80,21 +80,13 @@ class MultiplanarExternalTexture final : public Castable<MultiplanarExternalText | ||||
|     /// Destructor
 | ||||
|     ~MultiplanarExternalTexture() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   protected: | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -23,7 +23,11 @@ using MultiplanarExternalTextureTest = TransformTest; | ||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) { | ||||
|     auto* src = R"()"; | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||
|         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { | ||||
| @ -31,14 +35,22 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { | ||||
| type ET = texture_external; | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||
|         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||
| } | ||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) { | ||||
|     auto* src = R"( | ||||
| @group(0) @binding(0) var ext_tex : texture_external; | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||
|         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { | ||||
| @ -46,7 +58,11 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { | ||||
| fn f(ext_tex : texture_external) {} | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<MultiplanarExternalTexture::NewBindingPoints>( | ||||
|         MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data)); | ||||
| } | ||||
| 
 | ||||
| // Running the transform without passing in data for the new bindings should result in an error.
 | ||||
|  | ||||
| @ -29,6 +29,18 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* attr = node->As<ast::BuiltinAttribute>()) { | ||||
|             if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| /// Accessor describes the identifiers used in a member accessor that is being
 | ||||
| /// used to retrieve the num_workgroups builtin from a parameter.
 | ||||
| struct Accessor { | ||||
| @ -44,41 +56,40 @@ struct Accessor { | ||||
|         size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); } | ||||
|     }; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; | ||||
| NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; | ||||
| 
 | ||||
| bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* attr = node->As<ast::BuiltinAttribute>()) { | ||||
|             if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src, | ||||
|                                                        const DataMap& inputs, | ||||
|                                                        DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
| void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
|     auto* cfg = inputs.Get<Config>(); | ||||
|     if (cfg == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     const char* kNumWorkgroupsMemberName = "num_workgroups"; | ||||
| 
 | ||||
|     // Find all entry point parameters that declare the num_workgroups builtin.
 | ||||
|     std::unordered_set<Accessor, Accessor::Hasher> to_replace; | ||||
|     for (auto* func : ctx.src->AST().Functions()) { | ||||
|     for (auto* func : src->AST().Functions()) { | ||||
|         // num_workgroups is only valid for compute stages.
 | ||||
|         if (func->PipelineStage() != ast::PipelineStage::kCompute) { | ||||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         for (auto* param : ctx.src->Sem().Get(func)->Parameters()) { | ||||
|         for (auto* param : src->Sem().Get(func)->Parameters()) { | ||||
|             // Because the CanonicalizeEntryPointIO transform has been run, builtins
 | ||||
|             // will only appear as struct members.
 | ||||
|             auto* str = param->Type()->As<sem::Struct>(); | ||||
| @ -108,7 +119,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|                 // If this is the only member, remove the struct and parameter too.
 | ||||
|                 if (str->Members().size() == 1) { | ||||
|                     ctx.Remove(func->params, param->Declaration()); | ||||
|                     ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration()); | ||||
|                     ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration()); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @ -119,11 +130,10 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|     const ast::Variable* num_workgroups_ubo = nullptr; | ||||
|     auto get_ubo = [&]() { | ||||
|         if (!num_workgroups_ubo) { | ||||
|             auto* num_workgroups_struct = ctx.dst->Structure( | ||||
|                 ctx.dst->Sym(), | ||||
|                 utils::Vector{ | ||||
|                     ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())), | ||||
|                 }); | ||||
|             auto* num_workgroups_struct = | ||||
|                 b.Structure(b.Sym(), utils::Vector{ | ||||
|                                          b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())), | ||||
|                                      }); | ||||
| 
 | ||||
|             uint32_t group, binding; | ||||
|             if (cfg->ubo_binding.has_value()) { | ||||
| @ -135,9 +145,9 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|                 // plus 1, or group 0 if no resource bound.
 | ||||
|                 group = 0; | ||||
| 
 | ||||
|                 for (auto* global : ctx.src->AST().GlobalVariables()) { | ||||
|                 for (auto* global : src->AST().GlobalVariables()) { | ||||
|                     if (global->HasBindingPoint()) { | ||||
|                         auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global); | ||||
|                         auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global); | ||||
|                         auto binding_point = global_sem->BindingPoint(); | ||||
|                         if (binding_point.group >= group) { | ||||
|                             group = binding_point.group + 1; | ||||
| @ -148,16 +158,16 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|                 binding = 0; | ||||
|             } | ||||
| 
 | ||||
|             num_workgroups_ubo = ctx.dst->GlobalVar( | ||||
|                 ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, | ||||
|                 ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding))); | ||||
|             num_workgroups_ubo = | ||||
|                 b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, | ||||
|                             b.Group(AInt(group)), b.Binding(AInt(binding))); | ||||
|         } | ||||
|         return num_workgroups_ubo; | ||||
|     }; | ||||
| 
 | ||||
|     // Now replace all the places where the builtins are accessed with the value
 | ||||
|     // loaded from the uniform buffer.
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         auto* accessor = node->As<ast::MemberAccessorExpression>(); | ||||
|         if (!accessor) { | ||||
|             continue; | ||||
| @ -168,12 +178,12 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat | ||||
|         } | ||||
| 
 | ||||
|         if (to_replace.count({ident->symbol, accessor->member->symbol})) { | ||||
|             ctx.Replace(accessor, | ||||
|                         ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); | ||||
|             ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp) | ||||
|  | ||||
| @ -72,19 +72,10 @@ class NumWorkgroupsFromUniform final : public Castable<NumWorkgroupsFromUniform, | ||||
|         std::optional<sem::BindingPoint> ubo_binding; | ||||
|     }; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -28,7 +28,9 @@ using NumWorkgroupsFromUniformTest = TransformTest; | ||||
| TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) { | ||||
|     auto* src = R"()"; | ||||
| 
 | ||||
|     EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u}); | ||||
|     EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) { | ||||
| @ -38,7 +40,9 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) { | ||||
| } | ||||
| )"; | ||||
| 
 | ||||
|     EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src)); | ||||
|     DataMap data; | ||||
|     data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u}); | ||||
|     EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { | ||||
| @ -55,7 +59,6 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) { | ||||
|     DataMap data; | ||||
|     data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl); | ||||
|     auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data); | ||||
| 
 | ||||
|     EXPECT_EQ(expect, str(got)); | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -33,14 +33,15 @@ using namespace tint::number_suffixes;  // NOLINT | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// The PIMPL state for the PackedVec3 transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct PackedVec3::State { | ||||
|     /// Constructor
 | ||||
|     /// @param c the CloneContext
 | ||||
|     explicit State(CloneContext& c) : ctx(c) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         // Packed vec3<T> struct members
 | ||||
|         utils::Hashset<const sem::StructMember*, 8> members; | ||||
| 
 | ||||
| @ -72,6 +73,10 @@ struct PackedVec3::State { | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if (members.IsEmpty()) { | ||||
|             return SkipTransform; | ||||
|         } | ||||
| 
 | ||||
|         // Walk the nodes, starting with the most deeply nested, finding all the AST expressions
 | ||||
|         // that load a whole packed vector (not a scalar / swizzle of the vector).
 | ||||
|         utils::Hashset<const sem::Expression*, 16> refs; | ||||
| @ -137,36 +142,20 @@ struct PackedVec3::State { | ||||
|         } | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|     } | ||||
| 
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     /// @param program the program to inspect
 | ||||
|     static bool ShouldRun(const Program* program) { | ||||
|         for (auto* decl : program->AST().GlobalDeclarations()) { | ||||
|             if (auto* str = program->Sem().Get<sem::Struct>(decl)) { | ||||
|                 if (str->IsHostShareable()) { | ||||
|                     for (auto* member : str->Members()) { | ||||
|                         if (auto* vec = member->Type()->As<sem::Vector>()) { | ||||
|                             if (vec->Width() == 3) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return false; | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
|     /// Alias to the semantic info in ctx.src
 | ||||
|     const sem::Info& sem = ctx.src->Sem(); | ||||
|     /// Alias to the symbols in ctx.src
 | ||||
|     const SymbolTable& sym = ctx.src->Symbols(); | ||||
|     /// Alias to the ctx.dst program builder
 | ||||
|     ProgramBuilder& b = *ctx.dst; | ||||
| }; | ||||
| 
 | ||||
| PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {} | ||||
| @ -183,12 +172,8 @@ std::string PackedVec3::Attribute::InternalName() const { | ||||
| PackedVec3::PackedVec3() = default; | ||||
| PackedVec3::~PackedVec3() = default; | ||||
| 
 | ||||
| bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return State::ShouldRun(program); | ||||
| } | ||||
| 
 | ||||
| void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State(ctx).Run(); | ||||
| Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     return State{src}.Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -56,21 +56,13 @@ class PackedVec3 final : public Castable<PackedVec3, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~PackedVec3() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -50,8 +50,10 @@ PadStructs::PadStructs() = default; | ||||
| 
 | ||||
| PadStructs::~PadStructs() = default; | ||||
| 
 | ||||
| void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     auto& sem = ctx.src->Sem(); | ||||
| Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|     auto& sem = src->Sem(); | ||||
| 
 | ||||
|     std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs; | ||||
|     utils::Hashset<const ast::StructMember*, 8> padding_members; | ||||
| @ -65,7 +67,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|         bool has_runtime_sized_array = false; | ||||
|         utils::Vector<const ast::StructMember*, 8> new_members; | ||||
|         for (auto* mem : str->Members()) { | ||||
|             auto name = ctx.src->Symbols().NameFor(mem->Name()); | ||||
|             auto name = src->Symbols().NameFor(mem->Name()); | ||||
| 
 | ||||
|             if (offset < mem->Offset()) { | ||||
|                 CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset); | ||||
| @ -75,7 +77,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|             auto* ty = mem->Type(); | ||||
|             const ast::Type* type = CreateASTTypeFor(ctx, ty); | ||||
| 
 | ||||
|             new_members.Push(ctx.dst->Member(name, type)); | ||||
|             new_members.Push(b.Member(name, type)); | ||||
| 
 | ||||
|             uint32_t size = ty->Size(); | ||||
|             if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) { | ||||
| @ -97,8 +99,8 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|         if (offset < struct_size && !has_runtime_sized_array) { | ||||
|             CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset); | ||||
|         } | ||||
|         auto* new_struct = ctx.dst->create<ast::Struct>(ctx.Clone(ast_str->name), | ||||
|                                                         std::move(new_members), utils::Empty); | ||||
|         auto* new_struct = | ||||
|             b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty); | ||||
|         replaced_structs[ast_str] = new_struct; | ||||
|         return new_struct; | ||||
|     }); | ||||
| @ -131,16 +133,17 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|         auto* arg = ast_call->args.begin(); | ||||
|         for (auto* member : new_struct->members) { | ||||
|             if (padding_members.Contains(member)) { | ||||
|                 new_args.Push(ctx.dst->Expr(0_u)); | ||||
|                 new_args.Push(b.Expr(0_u)); | ||||
|             } else { | ||||
|                 new_args.Push(ctx.Clone(*arg)); | ||||
|                 arg++; | ||||
|             } | ||||
|         } | ||||
|         return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args); | ||||
|         return b.Construct(CreateASTTypeFor(ctx, str), new_args); | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -30,14 +30,10 @@ class PadStructs final : public Castable<PadStructs, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~PadStructs() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -13,6 +13,9 @@ | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include "src/tint/transform/promote_initializers_to_let.h" | ||||
| 
 | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/program_builder.h" | ||||
| #include "src/tint/sem/call.h" | ||||
| #include "src/tint/sem/statement.h" | ||||
| @ -27,9 +30,16 @@ PromoteInitializersToLet::PromoteInitializersToLet() = default; | ||||
| 
 | ||||
| PromoteInitializersToLet::~PromoteInitializersToLet() = default; | ||||
| 
 | ||||
| void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src, | ||||
|                                                        const DataMap&, | ||||
|                                                        DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||
| 
 | ||||
|     bool any_promoted = false; | ||||
| 
 | ||||
|     // Hoists array and structure initializers to a constant variable, declared
 | ||||
|     // just before the statement of usage.
 | ||||
|     auto promote = [&](const sem::Expression* expr) { | ||||
| @ -59,14 +69,15 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         any_promoted = true; | ||||
|         return hoist_to_decl_before.Add(expr, expr->Declaration(), true); | ||||
|     }; | ||||
| 
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         bool ok = Switch( | ||||
|             node,  //
 | ||||
|             [&](const ast::CallExpression* expr) { | ||||
|                 if (auto* sem = ctx.src->Sem().Get(expr)) { | ||||
|                 if (auto* sem = src->Sem().Get(expr)) { | ||||
|                     auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>(); | ||||
|                     if (ctor->Target()->Is<sem::TypeInitializer>()) { | ||||
|                         return promote(sem); | ||||
| @ -75,7 +86,7 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | ||||
|                 return true; | ||||
|             }, | ||||
|             [&](const ast::IdentifierExpression* expr) { | ||||
|                 if (auto* sem = ctx.src->Sem().Get(expr)) { | ||||
|                 if (auto* sem = src->Sem().Get(expr)) { | ||||
|                     if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) { | ||||
|                         // Identifier resolves to a variable
 | ||||
|                         if (auto* stmt = user->Stmt()) { | ||||
| @ -96,13 +107,17 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) | ||||
|                 return true; | ||||
|             }, | ||||
|             [&](Default) { return true; }); | ||||
| 
 | ||||
|         if (!ok) { | ||||
|             return; | ||||
|             return Program(std::move(b)); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (!any_promoted) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -33,14 +33,10 @@ class PromoteInitializersToLet final : public Castable<PromoteInitializersToLet, | ||||
|     /// Destructor
 | ||||
|     ~PromoteInitializersToLet() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -53,34 +53,36 @@ class StateBase { | ||||
| // to else {if}s so that the next transform, DecomposeSideEffects, can insert
 | ||||
| // hoisted expressions above their current location.
 | ||||
| struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> { | ||||
|     class State; | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; | ||||
|     ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| class SimplifySideEffectStatements::State : public StateBase { | ||||
|     HoistToDeclBefore hoist_to_decl_before; | ||||
| Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src, | ||||
|                                                            const DataMap&, | ||||
|                                                            DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|   public: | ||||
|     explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {} | ||||
|     bool made_changes = false; | ||||
| 
 | ||||
|     void Run() { | ||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|             if (auto* expr = node->As<ast::Expression>()) { | ||||
|                 auto* sem_expr = sem.Get(expr); | ||||
|                 if (!sem_expr || !sem_expr->HasSideEffects()) { | ||||
|                     continue; | ||||
|                 } | ||||
| 
 | ||||
|                 hoist_to_decl_before.Prepare(sem_expr); | ||||
|     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|         if (auto* expr = node->As<ast::Expression>()) { | ||||
|             auto* sem_expr = src->Sem().Get(expr); | ||||
|             if (!sem_expr || !sem_expr->HasSideEffects()) { | ||||
|                 continue; | ||||
|             } | ||||
|         } | ||||
|         ctx.Clone(); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State state(ctx); | ||||
|     state.Run(); | ||||
|             hoist_to_decl_before.Prepare(sem_expr); | ||||
|             made_changes = true; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| // Decomposes side-effecting expressions to ensure order of evaluation. This
 | ||||
| @ -89,7 +91,7 @@ void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMa | ||||
| struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> { | ||||
|     class CollectHoistsState; | ||||
|     class DecomposeState; | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; | ||||
|     ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| // CollectHoistsState traverses the AST top-down, identifying which expressions
 | ||||
| @ -667,12 +669,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase { | ||||
|             } | ||||
|             return nullptr; | ||||
|         }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src, | ||||
|                                                    const DataMap&, | ||||
|                                                    DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     // First collect side-effecting expressions to hoist
 | ||||
|     CollectHoistsState collect_hoists_state{ctx}; | ||||
|     auto to_hoist = collect_hoists_state.Run(); | ||||
| @ -680,6 +685,9 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
|     // Now decompose these expressions
 | ||||
|     DecomposeState decompose_state{ctx, std::move(to_hoist)}; | ||||
|     decompose_state.Run(); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| @ -687,13 +695,13 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons | ||||
| PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default; | ||||
| PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default; | ||||
| 
 | ||||
| Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const { | ||||
| Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src, | ||||
|                                                        const DataMap& inputs, | ||||
|                                                        DataMap& outputs) const { | ||||
|     transform::Manager manager; | ||||
|     manager.Add<SimplifySideEffectStatements>(); | ||||
|     manager.Add<DecomposeSideEffects>(); | ||||
| 
 | ||||
|     auto output = manager.Run(program, data); | ||||
|     return output; | ||||
|     return manager.Apply(src, inputs, outputs); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,12 +31,10 @@ class PromoteSideEffectsToDecl final : public Castable<PromoteSideEffectsToDecl, | ||||
|     /// Destructor
 | ||||
|     ~PromoteSideEffectsToDecl() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform on `program`, returning the transformation result.
 | ||||
|     /// @param program the source program to transform
 | ||||
|     /// @param data optional extra transform-specific data
 | ||||
|     /// @returns the transformation result
 | ||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -32,53 +32,19 @@ | ||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| class State { | ||||
|   private: | ||||
|     CloneContext& ctx; | ||||
|     ProgramBuilder& b; | ||||
|     const sem::Info& sem; | ||||
| 
 | ||||
|     // Map of switch statement to 'tint_continue' variable.
 | ||||
|     std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name; | ||||
| 
 | ||||
|     // If `cont` is within a switch statement within a loop, returns a pointer to
 | ||||
|     // that switch statement.
 | ||||
|     static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, | ||||
|                                                              const ast::ContinueStatement* cont) { | ||||
|         // Find whether first parent is a switch or a loop
 | ||||
|         auto* sem_stmt = sem.Get(cont); | ||||
|         auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement, | ||||
|                                                      sem::ForLoopStatement, sem::WhileStatement>(); | ||||
|         if (!sem_parent) { | ||||
|             return nullptr; | ||||
|         } | ||||
|         return sem_parent->Declaration()->As<ast::SwitchStatement>(); | ||||
|     } | ||||
| 
 | ||||
|   public: | ||||
| /// PIMPL state for the transform
 | ||||
| struct RemoveContinueInSwitch::State { | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the context
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} | ||||
| 
 | ||||
|     /// Returns true if this transform should be run for the given program
 | ||||
|     static bool ShouldRun(const Program* program) { | ||||
|         for (auto* node : program->ASTNodes().Objects()) { | ||||
|             auto* stmt = node->As<ast::ContinueStatement>(); | ||||
|             if (!stmt) { | ||||
|                 continue; | ||||
|             } | ||||
|             if (GetParentSwitchInLoop(program->Sem(), stmt)) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|         return false; | ||||
|     } | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|         for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         bool made_changes = false; | ||||
| 
 | ||||
|         for (auto* node : src->ASTNodes().Objects()) { | ||||
|             auto* cont = node->As<ast::ContinueStatement>(); | ||||
|             if (!cont) { | ||||
|                 continue; | ||||
| @ -90,6 +56,8 @@ class State { | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             made_changes = true; | ||||
| 
 | ||||
|             auto cont_var_name = | ||||
|                 tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() { | ||||
|                     // Create and insert 'var tint_continue : bool = false;' before the
 | ||||
| @ -116,22 +84,50 @@ class State { | ||||
|             ctx.Replace(cont, new_stmt); | ||||
|         } | ||||
| 
 | ||||
|         if (!made_changes) { | ||||
|             return SkipTransform; | ||||
|         } | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
|     /// Alias to src->sem
 | ||||
|     const sem::Info& sem = src->Sem(); | ||||
| 
 | ||||
|     // Map of switch statement to 'tint_continue' variable.
 | ||||
|     std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name; | ||||
| 
 | ||||
|     // If `cont` is within a switch statement within a loop, returns a pointer to
 | ||||
|     // that switch statement.
 | ||||
|     static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, | ||||
|                                                              const ast::ContinueStatement* cont) { | ||||
|         // Find whether first parent is a switch or a loop
 | ||||
|         auto* sem_stmt = sem.Get(cont); | ||||
|         auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement, | ||||
|                                                      sem::ForLoopStatement, sem::WhileStatement>(); | ||||
|         if (!sem_parent) { | ||||
|             return nullptr; | ||||
|         } | ||||
|         return sem_parent->Declaration()->As<ast::SwitchStatement>(); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| RemoveContinueInSwitch::RemoveContinueInSwitch() = default; | ||||
| RemoveContinueInSwitch::~RemoveContinueInSwitch() = default; | ||||
| 
 | ||||
| bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const { | ||||
|     return State::ShouldRun(program); | ||||
| } | ||||
| 
 | ||||
| void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State state(ctx); | ||||
|     state.Run(); | ||||
| Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src, | ||||
|                                                      const DataMap&, | ||||
|                                                      DataMap&) const { | ||||
|     State state(src); | ||||
|     return state.Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,19 +31,13 @@ class RemoveContinueInSwitch final : public Castable<RemoveContinueInSwitch, Tra | ||||
|     /// Destructor
 | ||||
|     ~RemoveContinueInSwitch() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -41,34 +41,25 @@ RemovePhonies::RemovePhonies() = default; | ||||
| 
 | ||||
| RemovePhonies::~RemovePhonies() = default; | ||||
| 
 | ||||
| bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (node->Is<ast::PhonyExpression>()) { | ||||
|             return true; | ||||
|         } | ||||
|         if (auto* stmt = node->As<ast::CallStatement>()) { | ||||
|             if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) { | ||||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
| void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     auto& sem = ctx.src->Sem(); | ||||
|     auto& sem = src->Sem(); | ||||
| 
 | ||||
|     std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks; | ||||
|     utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks; | ||||
| 
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     bool made_changes = false; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         Switch( | ||||
|             node, | ||||
|             [&](const ast::AssignmentStatement* stmt) { | ||||
|                 if (stmt->lhs->Is<ast::PhonyExpression>()) { | ||||
|                     made_changes = true; | ||||
| 
 | ||||
|                     std::vector<const ast::Expression*> side_effects; | ||||
|                     if (!ast::TraverseExpressions( | ||||
|                             stmt->rhs, ctx.dst->Diagnostics(), | ||||
|                             [&](const ast::CallExpression* expr) { | ||||
|                             stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) { | ||||
|                                 // ast::CallExpression may map to a function or builtin call
 | ||||
|                                 // (both may have side-effects), or a type initializer or
 | ||||
|                                 // type conversion (both do not have side effects).
 | ||||
| @ -100,8 +91,7 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|                         if (auto* call = side_effects[0]->As<ast::CallExpression>()) { | ||||
|                             // Phony assignment with single call side effect.
 | ||||
|                             // Replace phony assignment with call.
 | ||||
|                             ctx.Replace(stmt, | ||||
|                                         [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); }); | ||||
|                             ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); }); | ||||
|                             return; | ||||
|                         } | ||||
|                     } | ||||
| @ -114,22 +104,21 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|                         for (auto* arg : side_effects) { | ||||
|                             sig.push_back(sem.Get(arg)->Type()->UnwrapRef()); | ||||
|                         } | ||||
|                         auto sink = utils::GetOrCreate(sinks, sig, [&] { | ||||
|                             auto name = ctx.dst->Symbols().New("phony_sink"); | ||||
|                         auto sink = sinks.GetOrCreate(sig, [&] { | ||||
|                             auto name = b.Symbols().New("phony_sink"); | ||||
|                             utils::Vector<const ast::Parameter*, 8> params; | ||||
|                             for (auto* ty : sig) { | ||||
|                                 auto* ast_ty = CreateASTTypeFor(ctx, ty); | ||||
|                                 params.Push( | ||||
|                                     ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty)); | ||||
|                                 params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty)); | ||||
|                             } | ||||
|                             ctx.dst->Func(name, params, ctx.dst->ty.void_(), {}); | ||||
|                             b.Func(name, params, b.ty.void_(), {}); | ||||
|                             return name; | ||||
|                         }); | ||||
|                         utils::Vector<const ast::Expression*, 8> args; | ||||
|                         for (auto* arg : side_effects) { | ||||
|                             args.Push(ctx.Clone(arg)); | ||||
|                         } | ||||
|                         return ctx.dst->CallStmt(ctx.dst->Call(sink, args)); | ||||
|                         return b.CallStmt(b.Call(sink, args)); | ||||
|                     }); | ||||
|                 } | ||||
|             }, | ||||
| @ -138,12 +127,18 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|                 // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
 | ||||
|                 auto* sem_expr = sem.Get(stmt->expr); | ||||
|                 if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) { | ||||
|                     made_changes = true; | ||||
|                     ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt); | ||||
|                 } | ||||
|             }); | ||||
|     } | ||||
| 
 | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -33,19 +33,10 @@ class RemovePhonies final : public Castable<RemovePhonies, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~RemovePhonies() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -36,27 +36,28 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default; | ||||
| 
 | ||||
| RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; | ||||
| 
 | ||||
| bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* stmt = program->Sem().Get<sem::Statement>(node)) { | ||||
| Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src, | ||||
|                                                           const DataMap&, | ||||
|                                                           DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     bool made_changes = false; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* stmt = src->Sem().Get<sem::Statement>(node)) { | ||||
|             if (!stmt->IsReachable()) { | ||||
|                 return true; | ||||
|                 RemoveStatement(ctx, stmt->Declaration()); | ||||
|                 made_changes = true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|         if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) { | ||||
|             if (!stmt->IsReachable()) { | ||||
|                 RemoveStatement(ctx, stmt->Declaration()); | ||||
|             } | ||||
|         } | ||||
|     if (!made_changes) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -32,19 +32,10 @@ class RemoveUnreachableStatements final : public Castable<RemoveUnreachableState | ||||
|     /// Destructor
 | ||||
|     ~RemoveUnreachableStatements() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -1252,39 +1252,31 @@ Renamer::Config::~Config() = default; | ||||
| Renamer::Renamer() = default; | ||||
| Renamer::~Renamer() = default; | ||||
| 
 | ||||
| Output Renamer::Run(const Program* in, const DataMap& inputs) const { | ||||
|     ProgramBuilder out; | ||||
|     // Disable auto-cloning of symbols, since we want to rename them.
 | ||||
|     CloneContext ctx(&out, in, false); | ||||
| Transform::ApplyResult Renamer::Apply(const Program* src, | ||||
|                                       const DataMap& inputs, | ||||
|                                       DataMap& outputs) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ false}; | ||||
| 
 | ||||
|     // Swizzles, builtin calls and builtin structure members need to keep their
 | ||||
|     // symbols preserved.
 | ||||
|     std::unordered_set<const ast::IdentifierExpression*> preserve; | ||||
|     for (auto* node : in->ASTNodes().Objects()) { | ||||
|     utils::Hashset<const ast::IdentifierExpression*, 8> preserve; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* member = node->As<ast::MemberAccessorExpression>()) { | ||||
|             auto* sem = in->Sem().Get(member); | ||||
|             if (!sem) { | ||||
|                 TINT_ICE(Transform, out.Diagnostics()) | ||||
|                     << "MemberAccessorExpression has no semantic info"; | ||||
|                 continue; | ||||
|             } | ||||
|             auto* sem = src->Sem().Get(member); | ||||
|             if (sem->Is<sem::Swizzle>()) { | ||||
|                 preserve.emplace(member->member); | ||||
|             } else if (auto* str_expr = in->Sem().Get(member->structure)) { | ||||
|                 preserve.Add(member->member); | ||||
|             } else if (auto* str_expr = src->Sem().Get(member->structure)) { | ||||
|                 if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) { | ||||
|                     if (ty->Declaration() == nullptr) {  // Builtin structure
 | ||||
|                         preserve.emplace(member->member); | ||||
|                         preserve.Add(member->member); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } else if (auto* call = node->As<ast::CallExpression>()) { | ||||
|             auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|             if (!sem) { | ||||
|                 TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info"; | ||||
|                 continue; | ||||
|             } | ||||
|             auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|             if (sem->Target()->Is<sem::Builtin>()) { | ||||
|                 preserve.emplace(call->target.name); | ||||
|                 preserve.Add(call->target.name); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @ -1300,7 +1292,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { | ||||
|     } | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](Symbol sym_in) { | ||||
|         auto name_in = ctx.src->Symbols().NameFor(sym_in); | ||||
|         auto name_in = src->Symbols().NameFor(sym_in); | ||||
|         if (preserve_unicode || text::utf8::IsASCII(name_in)) { | ||||
|             switch (target) { | ||||
|                 case Target::kAll: | ||||
| @ -1343,17 +1335,20 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { | ||||
|     }); | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* { | ||||
|         if (preserve.count(ident)) { | ||||
|         if (preserve.Contains(ident)) { | ||||
|             auto sym_in = ident->symbol; | ||||
|             auto str = in->Symbols().NameFor(sym_in); | ||||
|             auto sym_out = out.Symbols().Register(str); | ||||
|             auto str = src->Symbols().NameFor(sym_in); | ||||
|             auto sym_out = b.Symbols().Register(str); | ||||
|             return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out); | ||||
|         } | ||||
|         return nullptr;  // Clone ident. Uses the symbol remapping above.
 | ||||
|     }); | ||||
|     ctx.Clone(); | ||||
| 
 | ||||
|     return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings))); | ||||
|     ctx.Clone();  // Must come before the std::move()
 | ||||
| 
 | ||||
|     outputs.Add<Data>(std::move(remappings)); | ||||
| 
 | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -85,11 +85,10 @@ class Renamer final : public Castable<Renamer, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~Renamer() override; | ||||
| 
 | ||||
|     /// Runs the transform on `program`, returning the transformation result.
 | ||||
|     /// @param program the source program to transform
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns the transformation result
 | ||||
|     Output Run(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -33,36 +33,48 @@ using namespace tint::number_suffixes;  // NOLINT | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// State holds the current transform state
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct Robustness::State { | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     /// Constructor
 | ||||
|     /// @param program the source program
 | ||||
|     /// @param omitted the omitted address spaces
 | ||||
|     State(const Program* program, std::unordered_set<ast::AddressSpace>&& omitted) | ||||
|         : src(program), omitted_address_spaces(std::move(omitted)) {} | ||||
| 
 | ||||
|     /// Set of address spacees to not apply the transform to
 | ||||
|     std::unordered_set<ast::AddressSpace> omitted_classes; | ||||
| 
 | ||||
|     /// Applies the transformation state to `ctx`.
 | ||||
|     void Transform() { | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); }); | ||||
|         ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// Set of address spaces to not apply the transform to
 | ||||
|     std::unordered_set<ast::AddressSpace> omitted_address_spaces; | ||||
| 
 | ||||
|     /// Apply bounds clamping to array, vector and matrix indexing
 | ||||
|     /// @param expr the array, vector or matrix index expression
 | ||||
|     /// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
 | ||||
|     /// changes.
 | ||||
|     const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) { | ||||
|         auto* sem = | ||||
|             ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>(); | ||||
|         auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>(); | ||||
|         auto* ret_type = sem->Type(); | ||||
| 
 | ||||
|         auto* ref = ret_type->As<sem::Reference>(); | ||||
|         if (ref && omitted_classes.count(ref->AddressSpace()) != 0) { | ||||
|         if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) { | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         ProgramBuilder& b = *ctx.dst; | ||||
| 
 | ||||
|         // idx return the cloned index expression, as a u32.
 | ||||
|         auto idx = [&]() -> const ast::Expression* { | ||||
|             auto* i = ctx.Clone(expr->index); | ||||
| @ -109,8 +121,8 @@ struct Robustness::State { | ||||
|                 } else { | ||||
|                     // Note: Don't be tempted to use the array override variable as an expression
 | ||||
|                     // here, the name might be shadowed!
 | ||||
|                     ctx.dst->Diagnostics().add_error(diag::System::Transform, | ||||
|                                                      sem::Array::kErrExpectedConstantCount); | ||||
|                     b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                               sem::Array::kErrExpectedConstantCount); | ||||
|                     return nullptr; | ||||
|                 } | ||||
| 
 | ||||
| @ -119,7 +131,7 @@ struct Robustness::State { | ||||
|             [&](Default) { | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "unhandled object type in robustness of array index: " | ||||
|                     << ctx.src->FriendlyName(ret_type->UnwrapRef()); | ||||
|                     << src->FriendlyName(ret_type->UnwrapRef()); | ||||
|                 return nullptr; | ||||
|             }); | ||||
| 
 | ||||
| @ -127,9 +139,9 @@ struct Robustness::State { | ||||
|             return nullptr;  // Clamping not needed
 | ||||
|         } | ||||
| 
 | ||||
|         auto src = ctx.Clone(expr->source); | ||||
|         auto* obj = ctx.Clone(expr->object); | ||||
|         return b.IndexAccessor(src, obj, clamped_idx); | ||||
|         auto idx_src = ctx.Clone(expr->source); | ||||
|         auto* idx_obj = ctx.Clone(expr->object); | ||||
|         return b.IndexAccessor(idx_src, idx_obj, clamped_idx); | ||||
|     } | ||||
| 
 | ||||
|     /// @param type builtin type
 | ||||
| @ -145,15 +157,13 @@ struct Robustness::State { | ||||
|     /// @return the clamped replacement call expression, or nullptr if `expr`
 | ||||
|     /// should be cloned without changes.
 | ||||
|     const ast::CallExpression* Transform(const ast::CallExpression* expr) { | ||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* call_target = call->Target(); | ||||
|         auto* builtin = call_target->As<sem::Builtin>(); | ||||
|         if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) { | ||||
|             return nullptr;  // No transform, just clone.
 | ||||
|         } | ||||
| 
 | ||||
|         ProgramBuilder& b = *ctx.dst; | ||||
| 
 | ||||
|         // Indices of the mandatory texture and coords parameters, and the optional
 | ||||
|         // array and level parameters.
 | ||||
|         auto& signature = builtin->Signature(); | ||||
| @ -261,7 +271,7 @@ struct Robustness::State { | ||||
|         // Clamp the level argument, if provided
 | ||||
|         if (level_idx >= 0) { | ||||
|             auto* arg = expr->args[static_cast<size_t>(level_idx)]; | ||||
|             ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a)); | ||||
|             ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a)); | ||||
|         } | ||||
| 
 | ||||
|         return nullptr;  // Clone, which will use the argument replacements above.
 | ||||
| @ -276,28 +286,27 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default; | ||||
| Robustness::Robustness() = default; | ||||
| Robustness::~Robustness() = default; | ||||
| 
 | ||||
| void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult Robustness::Apply(const Program* src, | ||||
|                                          const DataMap& inputs, | ||||
|                                          DataMap&) const { | ||||
|     Config cfg; | ||||
|     if (auto* cfg_data = inputs.Get<Config>()) { | ||||
|         cfg = *cfg_data; | ||||
|     } | ||||
| 
 | ||||
|     std::unordered_set<ast::AddressSpace> omitted_classes; | ||||
|     for (auto sc : cfg.omitted_classes) { | ||||
|     std::unordered_set<ast::AddressSpace> omitted_address_spaces; | ||||
|     for (auto sc : cfg.omitted_address_spaces) { | ||||
|         switch (sc) { | ||||
|             case AddressSpace::kUniform: | ||||
|                 omitted_classes.insert(ast::AddressSpace::kUniform); | ||||
|                 omitted_address_spaces.insert(ast::AddressSpace::kUniform); | ||||
|                 break; | ||||
|             case AddressSpace::kStorage: | ||||
|                 omitted_classes.insert(ast::AddressSpace::kStorage); | ||||
|                 omitted_address_spaces.insert(ast::AddressSpace::kStorage); | ||||
|                 break; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     State state{ctx, std::move(omitted_classes)}; | ||||
| 
 | ||||
|     state.Transform(); | ||||
|     ctx.Clone(); | ||||
|     return State{src, std::move(omitted_address_spaces)}.Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -54,9 +54,9 @@ class Robustness final : public Castable<Robustness, Transform> { | ||||
|         /// @returns this Config
 | ||||
|         Config& operator=(const Config&); | ||||
| 
 | ||||
|         /// Address spacees to omit from apply the transform to.
 | ||||
|         /// Address spaces to omit from apply the transform to.
 | ||||
|         /// This allows for optimizing on hardware that provide safe accesses.
 | ||||
|         std::unordered_set<AddressSpace> omitted_classes; | ||||
|         std::unordered_set<AddressSpace> omitted_address_spaces; | ||||
|     }; | ||||
| 
 | ||||
|     /// Constructor
 | ||||
| @ -64,14 +64,10 @@ class Robustness final : public Castable<Robustness, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~Robustness() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
|  | ||||
| @ -1274,7 +1274,7 @@ fn f() { | ||||
| )"; | ||||
| 
 | ||||
|     Robustness::Config cfg; | ||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); | ||||
|     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<Robustness::Config>(cfg); | ||||
| @ -1325,7 +1325,7 @@ fn f() { | ||||
| )"; | ||||
| 
 | ||||
|     Robustness::Config cfg; | ||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); | ||||
|     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<Robustness::Config>(cfg); | ||||
| @ -1376,8 +1376,8 @@ fn f() { | ||||
| )"; | ||||
| 
 | ||||
|     Robustness::Config cfg; | ||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); | ||||
|     cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); | ||||
|     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); | ||||
|     cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); | ||||
| 
 | ||||
|     DataMap data; | ||||
|     data.Add<Robustness::Config>(cfg); | ||||
|  | ||||
| @ -45,14 +45,18 @@ struct PointerOp { | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// The PIMPL state for the SimplifyPointers transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct SimplifyPointers::State { | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// Constructor
 | ||||
|     /// @param context the clone context
 | ||||
|     explicit State(CloneContext& context) : ctx(context) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Traverses the expression `expr` looking for non-literal array indexing
 | ||||
|     /// expressions that would affect the computed address of a pointer
 | ||||
| @ -120,10 +124,11 @@ struct SimplifyPointers::State { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Performs the transformation
 | ||||
|     void Run() { | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         // A map of saved expressions to their saved variable name
 | ||||
|         std::unordered_map<const ast::Expression*, Symbol> saved_vars; | ||||
|         utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars; | ||||
| 
 | ||||
|         // Register the ast::Expression transform handler.
 | ||||
|         // This performs two different transformations:
 | ||||
| @ -135,9 +140,8 @@ struct SimplifyPointers::State { | ||||
|         // variable identifier.
 | ||||
|         ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { | ||||
|             // Look to see if we need to swap this Expression with a saved variable.
 | ||||
|             auto it = saved_vars.find(expr); | ||||
|             if (it != saved_vars.end()) { | ||||
|                 return ctx.dst->Expr(it->second); | ||||
|             if (auto* saved_var = saved_vars.Find(expr)) { | ||||
|                 return ctx.dst->Expr(*saved_var); | ||||
|             } | ||||
| 
 | ||||
|             // Reduce the expression, folding away chains of address-of / indirections
 | ||||
| @ -174,7 +178,7 @@ struct SimplifyPointers::State { | ||||
| 
 | ||||
|                 // Scan the initializer expression for array index expressions that need
 | ||||
|                 // to be hoist to temporary "saved" variables.
 | ||||
|                 std::vector<const ast::VariableDeclStatement*> saved; | ||||
|                 utils::Vector<const ast::VariableDeclStatement*, 8> saved; | ||||
|                 CollectSavedArrayIndices( | ||||
|                     var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { | ||||
|                         // We have a sub-expression that needs to be saved.
 | ||||
| @ -182,18 +186,18 @@ struct SimplifyPointers::State { | ||||
|                         auto saved_name = ctx.dst->Symbols().New( | ||||
|                             ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); | ||||
|                         auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); | ||||
|                         saved.emplace_back(decl); | ||||
|                         saved.Push(decl); | ||||
|                         // Record the substitution of `idx_expr` to the saved variable
 | ||||
|                         // with the symbol `saved_name`. This will be used by the
 | ||||
|                         // ReplaceAll() handler above.
 | ||||
|                         saved_vars.emplace(idx_expr, saved_name); | ||||
|                         saved_vars.Add(idx_expr, saved_name); | ||||
|                     }); | ||||
| 
 | ||||
|                 // Find the place to insert the saved declarations.
 | ||||
|                 // Special care needs to be made for lets declared as the initializer
 | ||||
|                 // part of for-loops. In this case the block will hold the for-loop
 | ||||
|                 // statement, not the let.
 | ||||
|                 if (!saved.empty()) { | ||||
|                 if (!saved.IsEmpty()) { | ||||
|                     auto* stmt = ctx.src->Sem().Get(let); | ||||
|                     auto* block = stmt->Block(); | ||||
|                     // Find the statement owned by the block (either the let decl or a
 | ||||
| @ -219,7 +223,9 @@ struct SimplifyPointers::State { | ||||
|                 RemoveStatement(ctx, let); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| @ -227,8 +233,8 @@ SimplifyPointers::SimplifyPointers() = default; | ||||
| 
 | ||||
| SimplifyPointers::~SimplifyPointers() = default; | ||||
| 
 | ||||
| void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State(ctx).Run(); | ||||
| Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     return State(src).Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -39,16 +39,13 @@ class SimplifyPointers final : public Castable<SimplifyPointers, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~SimplifyPointers() override; | ||||
| 
 | ||||
|   protected: | ||||
|     struct State; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -30,33 +30,37 @@ SingleEntryPoint::SingleEntryPoint() = default; | ||||
| 
 | ||||
| SingleEntryPoint::~SingleEntryPoint() = default; | ||||
| 
 | ||||
| void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, | ||||
|                                                const DataMap& inputs, | ||||
|                                                DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     auto* cfg = inputs.Get<Config>(); | ||||
|     if (cfg == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error( | ||||
|             diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); | ||||
| 
 | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "missing transform data for " + std::string(TypeInfo().name)); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     // Find the target entry point.
 | ||||
|     const ast::Function* entry_point = nullptr; | ||||
|     for (auto* f : ctx.src->AST().Functions()) { | ||||
|     for (auto* f : src->AST().Functions()) { | ||||
|         if (!f->IsEntryPoint()) { | ||||
|             continue; | ||||
|         } | ||||
|         if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { | ||||
|         if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { | ||||
|             entry_point = f; | ||||
|             break; | ||||
|         } | ||||
|     } | ||||
|     if (entry_point == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, | ||||
|                                          "entry point '" + cfg->entry_point_name + "' not found"); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                   "entry point '" + cfg->entry_point_name + "' not found"); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     auto& sem = ctx.src->Sem(); | ||||
|     auto& sem = src->Sem(); | ||||
| 
 | ||||
|     // Build set of referenced module-scope variables for faster lookups later.
 | ||||
|     std::unordered_set<const ast::Variable*> referenced_vars; | ||||
| @ -66,12 +70,12 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c | ||||
| 
 | ||||
|     // Clone any module-scope variables, types, and functions that are statically referenced by the
 | ||||
|     // target entry point.
 | ||||
|     for (auto* decl : ctx.src->AST().GlobalDeclarations()) { | ||||
|     for (auto* decl : src->AST().GlobalDeclarations()) { | ||||
|         Switch( | ||||
|             decl,  //
 | ||||
|             [&](const ast::TypeDecl* ty) { | ||||
|                 // TODO(jrprice): Strip unused types.
 | ||||
|                 ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); | ||||
|                 b.AST().AddTypeDecl(ctx.Clone(ty)); | ||||
|             }, | ||||
|             [&](const ast::Override* override) { | ||||
|                 if (referenced_vars.count(override)) { | ||||
| @ -80,37 +84,39 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c | ||||
|                         // so that its allocated ID so that it won't be affected by other
 | ||||
|                         // stripped away overrides
 | ||||
|                         auto* global = sem.Get(override); | ||||
|                         const auto* id = ctx.dst->Id(global->OverrideId()); | ||||
|                         const auto* id = b.Id(global->OverrideId()); | ||||
|                         ctx.InsertFront(override->attributes, id); | ||||
|                     } | ||||
|                     ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); | ||||
|                     b.AST().AddGlobalVariable(ctx.Clone(override)); | ||||
|                 } | ||||
|             }, | ||||
|             [&](const ast::Var* var) { | ||||
|                 if (referenced_vars.count(var)) { | ||||
|                     ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); | ||||
|                     b.AST().AddGlobalVariable(ctx.Clone(var)); | ||||
|                 } | ||||
|             }, | ||||
|             [&](const ast::Const* c) { | ||||
|                 // Always keep 'const' declarations, as these can be used by attributes and array
 | ||||
|                 // sizes, which are not tracked as transitively used by functions. They also don't
 | ||||
|                 // typically get emitted by the backend unless they're actually used.
 | ||||
|                 ctx.dst->AST().AddGlobalVariable(ctx.Clone(c)); | ||||
|                 b.AST().AddGlobalVariable(ctx.Clone(c)); | ||||
|             }, | ||||
|             [&](const ast::Function* func) { | ||||
|                 if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { | ||||
|                     ctx.dst->AST().AddFunction(ctx.Clone(func)); | ||||
|                     b.AST().AddFunction(ctx.Clone(func)); | ||||
|                 } | ||||
|             }, | ||||
|             [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); }, | ||||
|             [&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); }, | ||||
|             [&](Default) { | ||||
|                 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) | ||||
|                 TINT_UNREACHABLE(Transform, b.Diagnostics()) | ||||
|                     << "unhandled global declaration: " << decl->TypeInfo().name; | ||||
|             }); | ||||
|     } | ||||
| 
 | ||||
|     // Clone the entry point.
 | ||||
|     ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); | ||||
|     b.AST().AddFunction(ctx.Clone(entry_point)); | ||||
| 
 | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {} | ||||
|  | ||||
| @ -53,14 +53,10 @@ class SingleEntryPoint final : public Castable<SingleEntryPoint, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~SingleEntryPoint() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -37,7 +37,7 @@ namespace tint::transform { | ||||
| 
 | ||||
| using namespace tint::number_suffixes;  // NOLINT
 | ||||
| 
 | ||||
| /// Private implementation of transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct SpirvAtomic::State { | ||||
|   private: | ||||
|     /// A struct that has been forked because a subset of members were made atomic.
 | ||||
| @ -46,19 +46,24 @@ struct SpirvAtomic::State { | ||||
|         std::unordered_set<size_t> atomic_members; | ||||
|     }; | ||||
| 
 | ||||
|     CloneContext& ctx; | ||||
|     ProgramBuilder& b = *ctx.dst; | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
|     std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs; | ||||
|     std::unordered_set<const sem::Variable*> atomic_variables; | ||||
|     utils::UniqueVector<const sem::Expression*, 8> atomic_expressions; | ||||
| 
 | ||||
|   public: | ||||
|     /// Constructor
 | ||||
|     /// @param c the clone context
 | ||||
|     explicit State(CloneContext& c) : ctx(c) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
 | ||||
|         // for atomic builtin calls.
 | ||||
|         for (auto* fn : ctx.src->AST().Functions()) { | ||||
| @ -102,6 +107,10 @@ struct SpirvAtomic::State { | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if (atomic_expressions.IsEmpty()) { | ||||
|             return SkipTransform; | ||||
|         } | ||||
| 
 | ||||
|         // Transform all variables and structure members that were used in atomic operations as
 | ||||
|         // atomic types. This propagates up originating expression chains.
 | ||||
|         ProcessAtomicExpressions(); | ||||
| @ -143,6 +152,7 @@ struct SpirvAtomic::State { | ||||
|         ReplaceLoadsAndStores(); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
| @ -297,17 +307,8 @@ const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const { | ||||
|                                                           ctx->dst->AllocateNodeID(), builtin); | ||||
| } | ||||
| 
 | ||||
| bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     for (auto* fn : program->AST().Functions()) { | ||||
|         if (ast::HasAttribute<Stub>(fn->attributes)) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State{ctx}.Run(); | ||||
| Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     return State{src}.Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -63,21 +63,13 @@ class SpirvAtomic final : public Castable<SpirvAtomic, Transform> { | ||||
|         const sem::BuiltinType builtin; | ||||
|     }; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   protected: | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -77,14 +77,20 @@ struct Hasher<DynamicIndex> { | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// The PIMPL state for the Std140 transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct Std140::State { | ||||
|     /// Constructor
 | ||||
|     /// @param c the CloneContext
 | ||||
|     explicit State(CloneContext& c) : ctx(c) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         if (!ShouldRun()) { | ||||
|             // Transform is not required
 | ||||
|             return SkipTransform; | ||||
|         } | ||||
| 
 | ||||
|         // Begin by creating forked types for any type that is used as a uniform buffer, that
 | ||||
|         // either directly or transitively contains a matrix that needs splitting for std140 layout.
 | ||||
|         ForkTypes(); | ||||
| @ -116,11 +122,11 @@ struct Std140::State { | ||||
|         }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     /// @param program the program to inspect
 | ||||
|     static bool ShouldRun(const Program* program) { | ||||
|     bool ShouldRun() const { | ||||
|         // Returns true if the type needs to be forked for std140 usage.
 | ||||
|         auto needs_fork = [&](const sem::Type* ty) { | ||||
|             while (auto* arr = ty->As<sem::Array>()) { | ||||
| @ -135,7 +141,7 @@ struct Std140::State { | ||||
|         }; | ||||
| 
 | ||||
|         // Scan structures for members that need forking
 | ||||
|         for (auto* ty : program->Types()) { | ||||
|         for (auto* ty : src->Types()) { | ||||
|             if (auto* str = ty->As<sem::Struct>()) { | ||||
|                 if (str->UsedAs(ast::AddressSpace::kUniform)) { | ||||
|                     for (auto* member : str->Members()) { | ||||
| @ -148,8 +154,8 @@ struct Std140::State { | ||||
|         } | ||||
| 
 | ||||
|         // Scan uniform variables that have types that need forking
 | ||||
|         for (auto* decl : program->AST().GlobalVariables()) { | ||||
|             auto* global = program->Sem().Get(decl); | ||||
|         for (auto* decl : src->AST().GlobalVariables()) { | ||||
|             auto* global = src->Sem().Get(decl); | ||||
|             if (global->AddressSpace() == ast::AddressSpace::kUniform) { | ||||
|                 if (needs_fork(global->Type()->UnwrapRef())) { | ||||
|                     return true; | ||||
| @ -197,14 +203,16 @@ struct Std140::State { | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     /// Alias to the semantic info in ctx.src
 | ||||
|     const sem::Info& sem = ctx.src->Sem(); | ||||
|     /// Alias to the symbols in ctx.src
 | ||||
|     const SymbolTable& sym = ctx.src->Symbols(); | ||||
|     /// Alias to the ctx.dst program builder
 | ||||
|     ProgramBuilder& b = *ctx.dst; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
|     /// Alias to the semantic info in src
 | ||||
|     const sem::Info& sem = src->Sem(); | ||||
|     /// Alias to the symbols in src
 | ||||
|     const SymbolTable& sym = src->Symbols(); | ||||
| 
 | ||||
|     /// Map of load function signature, to the generated function
 | ||||
|     utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns; | ||||
| @ -218,7 +226,7 @@ struct Std140::State { | ||||
|     // Map of original structure to 'std140' forked structure
 | ||||
|     utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs; | ||||
| 
 | ||||
|     // Map of structure member in ctx.src of a matrix type, to list of decomposed column
 | ||||
|     // Map of structure member in src of a matrix type, to list of decomposed column
 | ||||
|     // members in ctx.dst.
 | ||||
|     utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8> | ||||
|         std140_mat_members; | ||||
| @ -232,7 +240,7 @@ struct Std140::State { | ||||
|         utils::Vector<Symbol, 4> columns; | ||||
|     }; | ||||
| 
 | ||||
|     // Map of matrix type in ctx.src, to decomposed column structure in ctx.dst.
 | ||||
|     // Map of matrix type in src, to decomposed column structure in ctx.dst.
 | ||||
|     utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats; | ||||
| 
 | ||||
|     /// AccessChain describes a chain of access expressions to uniform buffer variable.
 | ||||
| @ -266,7 +274,7 @@ struct Std140::State { | ||||
|     /// map (via Std140Type()).
 | ||||
|     void ForkTypes() { | ||||
|         // For each module scope declaration...
 | ||||
|         for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) { | ||||
|         for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) { | ||||
|             // Check to see if this is a structure used by a uniform buffer...
 | ||||
|             auto* str = sem.Get<sem::Struct>(global); | ||||
|             if (str && str->UsedAs(ast::AddressSpace::kUniform)) { | ||||
| @ -317,7 +325,7 @@ struct Std140::State { | ||||
|                 if (fork_std140) { | ||||
|                     // Clone any members that have not already been cloned.
 | ||||
|                     for (auto& member : members) { | ||||
|                         if (member->program_id == ctx.src->ID()) { | ||||
|                         if (member->program_id == src->ID()) { | ||||
|                             member = ctx.Clone(member); | ||||
|                         } | ||||
|                     } | ||||
| @ -326,7 +334,7 @@ struct Std140::State { | ||||
|                     auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140"); | ||||
|                     auto* std140 = b.create<ast::Struct>(name, std::move(members), | ||||
|                                                          ctx.Clone(str->Declaration()->attributes)); | ||||
|                     ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140); | ||||
|                     ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140); | ||||
|                     std140_structs.Add(str, name); | ||||
|                 } | ||||
|             } | ||||
| @ -337,14 +345,13 @@ struct Std140::State { | ||||
|     /// type that has been forked for std140-layout.
 | ||||
|     /// Populates the #std140_uniforms set.
 | ||||
|     void ReplaceUniformVarTypes() { | ||||
|         for (auto* global : ctx.src->AST().GlobalVariables()) { | ||||
|         for (auto* global : src->AST().GlobalVariables()) { | ||||
|             if (auto* var = global->As<ast::Var>()) { | ||||
|                 if (var->declared_address_space == ast::AddressSpace::kUniform) { | ||||
|                     auto* v = sem.Get(var); | ||||
|                     if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) { | ||||
|                         ctx.Replace(global->type, std140_ty); | ||||
|                         std140_uniforms.Add(v); | ||||
|                         continue; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| @ -404,7 +411,7 @@ struct Std140::State { | ||||
|                     auto std140_mat = std140_mats.GetOrCreate(mat, [&] { | ||||
|                         auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" + | ||||
|                                                     std::to_string(mat->rows()) + "_" + | ||||
|                                                     ctx.src->FriendlyName(mat->type())); | ||||
|                                                     src->FriendlyName(mat->type())); | ||||
|                         auto members = | ||||
|                             DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size()); | ||||
|                         b.Structure(name, members); | ||||
| @ -421,7 +428,7 @@ struct Std140::State { | ||||
|                 if (auto* std140 = Std140Type(arr->ElemType())) { | ||||
|                     utils::Vector<const ast::Attribute*, 1> attrs; | ||||
|                     if (!arr->IsStrideImplicit()) { | ||||
|                         attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride())); | ||||
|                         attrs.Push(b.create<ast::StrideAttribute>(arr->Stride())); | ||||
|                     } | ||||
|                     auto count = arr->ConstantCount(); | ||||
|                     if (!count) { | ||||
| @ -429,7 +436,7 @@ struct Std140::State { | ||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||
|                         //   this method only handles types transitively used as uniform buffers.
 | ||||
|                         // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "unexpected non-constant array count"; | ||||
|                         count = 1; | ||||
|                     } | ||||
| @ -440,7 +447,7 @@ struct Std140::State { | ||||
|             }); | ||||
|     } | ||||
| 
 | ||||
|     /// @param mat the matrix to decompose (in ctx.src)
 | ||||
|     /// @param mat the matrix to decompose (in src)
 | ||||
|     /// @param name_prefix the name prefix to apply to each of the returned column vector members.
 | ||||
|     /// @param align the alignment in bytes of the matrix.
 | ||||
|     /// @param size the size in bytes of the matrix.
 | ||||
| @ -473,7 +480,7 @@ struct Std140::State { | ||||
|             // Build the member
 | ||||
|             const auto col_name = name_prefix + std::to_string(i); | ||||
|             const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType()); | ||||
|             const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes)); | ||||
|             const auto* col_member = b.Member(col_name, col_ty, std::move(attributes)); | ||||
|             // Record the member for std140_mat_members
 | ||||
|             out.Push(col_member); | ||||
|         } | ||||
| @ -618,7 +625,7 @@ struct Std140::State { | ||||
| 
 | ||||
|     /// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
 | ||||
|     ///          being converted.
 | ||||
|     const std::string ConvertSuffix(const sem::Type* ty) const { | ||||
|     const std::string ConvertSuffix(const sem::Type* ty) { | ||||
|         return Switch( | ||||
|             ty,  //
 | ||||
|             [&](const sem::Struct* str) { return sym.NameFor(str->Name()); }, | ||||
| @ -629,8 +636,7 @@ struct Std140::State { | ||||
|                     // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||
|                     //   this method only handles types transitively used as uniform buffers.
 | ||||
|                     // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||
|                     TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         << "unexpected non-constant array count"; | ||||
|                     TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; | ||||
|                     count = 1; | ||||
|                 } | ||||
|                 return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType()); | ||||
| @ -642,7 +648,7 @@ struct Std140::State { | ||||
|             [&](const sem::F32*) { return "f32"; }, | ||||
|             [&](Default) { | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "unhandled type for conversion name: " << ctx.src->FriendlyName(ty); | ||||
|                     << "unhandled type for conversion name: " << src->FriendlyName(ty); | ||||
|                 return ""; | ||||
|             }); | ||||
|     } | ||||
| @ -718,8 +724,7 @@ struct Std140::State { | ||||
|                         stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args)))); | ||||
|                     } else { | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "failed to find std140 matrix info for: " | ||||
|                             << ctx.src->FriendlyName(ty); | ||||
|                             << "failed to find std140 matrix info for: " << src->FriendlyName(ty); | ||||
|                     } | ||||
|                 },  //
 | ||||
|                 [&](const sem::Array* arr) { | ||||
| @ -736,7 +741,7 @@ struct Std140::State { | ||||
|                         // * Override-expression counts can only be applied to workgroup arrays, and
 | ||||
|                         //   this method only handles types transitively used as uniform buffers.
 | ||||
|                         // * Runtime-sized arrays cannot be used in uniform buffers.
 | ||||
|                         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                         TINT_ICE(Transform, b.Diagnostics()) | ||||
|                             << "unexpected non-constant array count"; | ||||
|                         count = 1; | ||||
|                     } | ||||
| @ -749,7 +754,7 @@ struct Std140::State { | ||||
|                 }, | ||||
|                 [&](Default) { | ||||
|                     TINT_ICE(Transform, b.Diagnostics()) | ||||
|                         << "unhandled type for conversion: " << ctx.src->FriendlyName(ty); | ||||
|                         << "unhandled type for conversion: " << src->FriendlyName(ty); | ||||
|                 }); | ||||
| 
 | ||||
|             // Generate the function
 | ||||
| @ -1063,7 +1068,7 @@ struct Std140::State { | ||||
| 
 | ||||
|         if (std::get_if<UniformVariable>(&access)) { | ||||
|             const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol)); | ||||
|             const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol); | ||||
|             const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol); | ||||
|             ty = chain.var->Type()->UnwrapRef(); | ||||
|             return {expr, ty, name}; | ||||
|         } | ||||
| @ -1090,7 +1095,7 @@ struct Std140::State { | ||||
|                 },  //
 | ||||
|                 [&](Default) -> ExprTypeName { | ||||
|                     TINT_ICE(Transform, b.Diagnostics()) | ||||
|                         << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); | ||||
|                         << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||
|                     return {}; | ||||
|                 }); | ||||
|         } | ||||
| @ -1104,14 +1109,14 @@ struct Std140::State { | ||||
|                     for (auto el : *swizzle) { | ||||
|                         rhs += xyzw[el]; | ||||
|                     } | ||||
|                     auto swizzle_ty = ctx.src->Types().Find<sem::Vector>( | ||||
|                     auto swizzle_ty = src->Types().Find<sem::Vector>( | ||||
|                         vec->type(), static_cast<uint32_t>(swizzle->Length())); | ||||
|                     auto* expr = b.MemberAccessor(lhs, rhs); | ||||
|                     return {expr, swizzle_ty, rhs}; | ||||
|                 },  //
 | ||||
|                 [&](Default) -> ExprTypeName { | ||||
|                     TINT_ICE(Transform, b.Diagnostics()) | ||||
|                         << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); | ||||
|                         << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||
|                     return {}; | ||||
|                 }); | ||||
|         } | ||||
| @ -1140,7 +1145,7 @@ struct Std140::State { | ||||
|             },  //
 | ||||
|             [&](Default) -> ExprTypeName { | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) | ||||
|                     << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); | ||||
|                     << "unhandled type for access chain: " << src->FriendlyName(ty); | ||||
|                 return {}; | ||||
|             }); | ||||
|     } | ||||
| @ -1150,12 +1155,8 @@ Std140::Std140() = default; | ||||
| 
 | ||||
| Std140::~Std140() = default; | ||||
| 
 | ||||
| bool Std140::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return State::ShouldRun(program); | ||||
| } | ||||
| 
 | ||||
| void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State(ctx).Run(); | ||||
| Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     return State(src).Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -34,21 +34,13 @@ class Std140 final : public Castable<Std140, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~Std140() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| #include "src/tint/transform/substitute_override.h" | ||||
| 
 | ||||
| #include <functional> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/program_builder.h" | ||||
| #include "src/tint/sem/builtin.h" | ||||
| @ -25,12 +26,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride); | ||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| SubstituteOverride::SubstituteOverride() = default; | ||||
| 
 | ||||
| SubstituteOverride::~SubstituteOverride() = default; | ||||
| 
 | ||||
| bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->AST().GlobalVariables()) { | ||||
|         if (node->Is<ast::Override>()) { | ||||
|             return true; | ||||
| @ -39,18 +37,32 @@ bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| SubstituteOverride::SubstituteOverride() = default; | ||||
| 
 | ||||
| SubstituteOverride::~SubstituteOverride() = default; | ||||
| 
 | ||||
| Transform::ApplyResult SubstituteOverride::Apply(const Program* src, | ||||
|                                                  const DataMap& config, | ||||
|                                                  DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     const auto* data = config.Get<Config>(); | ||||
|     if (!data) { | ||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, | ||||
|                                          "Missing override substitution data"); | ||||
|         return; | ||||
|         b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data"); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|     if (!ShouldRun(ctx.src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* { | ||||
|         auto* sem = ctx.src->Sem().Get(w); | ||||
| 
 | ||||
|         auto src = ctx.Clone(w->source); | ||||
|         auto source = ctx.Clone(w->source); | ||||
|         auto sym = ctx.Clone(w->symbol); | ||||
|         auto* ty = ctx.Clone(w->type); | ||||
| 
 | ||||
| @ -58,30 +70,30 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | ||||
|         auto iter = data->map.find(sem->OverrideId()); | ||||
|         if (iter == data->map.end()) { | ||||
|             if (!w->initializer) { | ||||
|                 ctx.dst->Diagnostics().add_error( | ||||
|                 b.Diagnostics().add_error( | ||||
|                     diag::System::Transform, | ||||
|                     "Initializer not provided for override, and override not overridden."); | ||||
|                 return nullptr; | ||||
|             } | ||||
|             return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer)); | ||||
|             return b.Const(source, sym, ty, ctx.Clone(w->initializer)); | ||||
|         } | ||||
| 
 | ||||
|         auto value = iter->second; | ||||
|         auto* ctor = Switch( | ||||
|             sem->Type(), | ||||
|             [&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to<double>()(value, 0.0)); }, | ||||
|             [&](const sem::I32*) { return ctx.dst->Expr(i32(value)); }, | ||||
|             [&](const sem::U32*) { return ctx.dst->Expr(u32(value)); }, | ||||
|             [&](const sem::F32*) { return ctx.dst->Expr(f32(value)); }, | ||||
|             [&](const sem::F16*) { return ctx.dst->Expr(f16(value)); }); | ||||
|             [&](const sem::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); }, | ||||
|             [&](const sem::I32*) { return b.Expr(i32(value)); }, | ||||
|             [&](const sem::U32*) { return b.Expr(u32(value)); }, | ||||
|             [&](const sem::F32*) { return b.Expr(f32(value)); }, | ||||
|             [&](const sem::F16*) { return b.Expr(f16(value)); }); | ||||
| 
 | ||||
|         if (!ctor) { | ||||
|             ctx.dst->Diagnostics().add_error(diag::System::Transform, | ||||
|                                              "Failed to create override-expression"); | ||||
|             b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                       "Failed to create override-expression"); | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         return ctx.dst->Const(src, sym, ty, ctor); | ||||
|         return b.Const(source, sym, ty, ctor); | ||||
|     }); | ||||
| 
 | ||||
|     // Ensure that objects that are indexed with an override-expression are materialized.
 | ||||
| @ -89,11 +101,10 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | ||||
|     // resulting type of the index may change. See: crbug.com/tint/1697.
 | ||||
|     ctx.ReplaceAll( | ||||
|         [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { | ||||
|             if (auto* sem = ctx.src->Sem().Get(expr)) { | ||||
|             if (auto* sem = src->Sem().Get(expr)) { | ||||
|                 if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) { | ||||
|                     if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() && | ||||
|                         access->Index()->Stage() == sem::EvaluationStage::kOverride) { | ||||
|                         auto& b = *ctx.dst; | ||||
|                         auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize), | ||||
|                                            ctx.Clone(expr->object)); | ||||
|                         return b.IndexAccessor(obj, ctx.Clone(expr->index)); | ||||
| @ -104,6 +115,7 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) | ||||
|         }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| SubstituteOverride::Config::Config() = default; | ||||
|  | ||||
| @ -75,19 +75,10 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform> | ||||
|     /// Destructor
 | ||||
|     ~SubstituteOverride() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -122,7 +122,18 @@ class TransformTestBase : public BASE { | ||||
|         } | ||||
| 
 | ||||
|         const Transform& t = TRANSFORM(); | ||||
|         return t.ShouldRun(&program, data); | ||||
| 
 | ||||
|         DataMap outputs; | ||||
|         auto result = t.Apply(&program, data, outputs); | ||||
|         if (!result) { | ||||
|             return false; | ||||
|         } | ||||
|         if (!result->IsValid()) { | ||||
|             ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: " | ||||
|                           << result->Diagnostics().str(); | ||||
|             return true; | ||||
|         } | ||||
|         return result.has_value(); | ||||
|     } | ||||
| 
 | ||||
|     /// @param in the input WGSL source
 | ||||
|  | ||||
| @ -46,24 +46,19 @@ Output::Output(Program&& p) : program(std::move(p)) {} | ||||
| Transform::Transform() = default; | ||||
| Transform::~Transform() = default; | ||||
| 
 | ||||
| Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const { | ||||
|     ProgramBuilder builder; | ||||
|     CloneContext ctx(&builder, program); | ||||
| Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const { | ||||
|     Output output; | ||||
|     Run(ctx, data, output.data); | ||||
|     output.program = Program(std::move(builder)); | ||||
|     if (auto program = Apply(src, data, output.data)) { | ||||
|         output.program = std::move(program.value()); | ||||
|     } else { | ||||
|         ProgramBuilder b; | ||||
|         CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
|         ctx.Clone(); | ||||
|         output.program = Program(std::move(b)); | ||||
|     } | ||||
|     return output; | ||||
| } | ||||
| 
 | ||||
| void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics()) | ||||
|         << "Transform::Run() unimplemented for " << TypeInfo().name; | ||||
| } | ||||
| 
 | ||||
| bool Transform::ShouldRun(const Program*, const DataMap&) const { | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
| void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) { | ||||
|     auto* sem = ctx.src->Sem().Get(stmt); | ||||
|     if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) { | ||||
|  | ||||
| @ -158,26 +158,30 @@ class Transform : public Castable<Transform> { | ||||
|     /// Destructor
 | ||||
|     ~Transform() override; | ||||
| 
 | ||||
|     /// Runs the transform on `program`, returning the transformation result.
 | ||||
|     /// Runs the transform on @p program, returning the transformation result or a clone of
 | ||||
|     /// @p program.
 | ||||
|     /// @param program the source program to transform
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns the transformation result
 | ||||
|     virtual Output Run(const Program* program, const DataMap& data = {}) const; | ||||
|     Output Run(const Program* program, const DataMap& data = {}) const; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const; | ||||
|     /// The return value of Apply().
 | ||||
|     /// If SkipTransform (std::nullopt), then the transform is not needed to be run.
 | ||||
|     using ApplyResult = std::optional<Program>; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// Value returned from Apply() to indicate that the transform does not need to be run
 | ||||
|     static inline constexpr std::nullopt_t SkipTransform = std::nullopt; | ||||
| 
 | ||||
|     /// Runs the transform on `program`, return.
 | ||||
|     /// @param program the input program
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const; | ||||
|     /// @returns a transformed program, or std::nullopt if the transform didn't need to run.
 | ||||
|     virtual ApplyResult Apply(const Program* program, | ||||
|                               const DataMap& inputs, | ||||
|                               DataMap& outputs) const = 0; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Removes the statement `stmt` from the transformed program.
 | ||||
|     /// RemoveStatement handles edge cases, like statements in the initializer and
 | ||||
|     /// continuing of for-loops.
 | ||||
|  | ||||
| @ -23,7 +23,9 @@ namespace { | ||||
| 
 | ||||
| // Inherit from Transform so we have access to protected methods
 | ||||
| struct CreateASTTypeForTest : public testing::Test, public Transform { | ||||
|     Output Run(const Program*, const DataMap&) const override { return {}; } | ||||
|     ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) { | ||||
|         ProgramBuilder sem_type_builder; | ||||
|  | ||||
| @ -28,27 +28,32 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// The PIMPL state for the Unshadow transform
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct Unshadow::State { | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext& ctx; | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     /// Constructor
 | ||||
|     /// @param context the clone context
 | ||||
|     explicit State(CloneContext& context) : ctx(context) {} | ||||
|     /// @param program the source program
 | ||||
|     explicit State(const Program* program) : src(program) {} | ||||
| 
 | ||||
|     /// Performs the transformation
 | ||||
|     void Run() { | ||||
|         auto& sem = ctx.src->Sem(); | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     Transform::ApplyResult Run() { | ||||
|         auto& sem = src->Sem(); | ||||
| 
 | ||||
|         // Maps a variable to its new name.
 | ||||
|         std::unordered_map<const sem::Variable*, Symbol> renamed_to; | ||||
|         utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to; | ||||
| 
 | ||||
|         auto rename = [&](const sem::Variable* v) -> const ast::Variable* { | ||||
|             auto* decl = v->Declaration(); | ||||
|             auto name = ctx.src->Symbols().NameFor(decl->symbol); | ||||
|             auto symbol = ctx.dst->Symbols().New(name); | ||||
|             renamed_to.emplace(v, symbol); | ||||
|             auto name = src->Symbols().NameFor(decl->symbol); | ||||
|             auto symbol = b.Symbols().New(name); | ||||
|             renamed_to.Add(v, symbol); | ||||
| 
 | ||||
|             auto source = ctx.Clone(decl->source); | ||||
|             auto* type = ctx.Clone(decl->type); | ||||
| @ -57,20 +62,20 @@ struct Unshadow::State { | ||||
|             return Switch( | ||||
|                 decl,  //
 | ||||
|                 [&](const ast::Var* var) { | ||||
|                     return ctx.dst->Var(source, symbol, type, var->declared_address_space, | ||||
|                                         var->declared_access, initializer, attributes); | ||||
|                     return b.Var(source, symbol, type, var->declared_address_space, | ||||
|                                  var->declared_access, initializer, attributes); | ||||
|                 }, | ||||
|                 [&](const ast::Let*) { | ||||
|                     return ctx.dst->Let(source, symbol, type, initializer, attributes); | ||||
|                     return b.Let(source, symbol, type, initializer, attributes); | ||||
|                 }, | ||||
|                 [&](const ast::Const*) { | ||||
|                     return ctx.dst->Const(source, symbol, type, initializer, attributes); | ||||
|                     return b.Const(source, symbol, type, initializer, attributes); | ||||
|                 }, | ||||
|                 [&](const ast::Parameter*) { | ||||
|                     return ctx.dst->Param(source, symbol, type, attributes); | ||||
|                 [&](const ast::Parameter*) {  //
 | ||||
|                     return b.Param(source, symbol, type, attributes); | ||||
|                 }, | ||||
|                 [&](Default) { | ||||
|                     TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|                     TINT_ICE(Transform, b.Diagnostics()) | ||||
|                         << "unexpected variable type: " << decl->TypeInfo().name; | ||||
|                     return nullptr; | ||||
|                 }); | ||||
| @ -92,14 +97,15 @@ struct Unshadow::State { | ||||
|         ctx.ReplaceAll( | ||||
|             [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { | ||||
|                 if (auto* user = sem.Get<sem::VariableUser>(ident)) { | ||||
|                     auto it = renamed_to.find(user->Variable()); | ||||
|                     if (it != renamed_to.end()) { | ||||
|                         return ctx.dst->Expr(it->second); | ||||
|                     if (auto* renamed = renamed_to.Find(user->Variable())) { | ||||
|                         return b.Expr(*renamed); | ||||
|                     } | ||||
|                 } | ||||
|                 return nullptr; | ||||
|             }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| @ -107,8 +113,8 @@ Unshadow::Unshadow() = default; | ||||
| 
 | ||||
| Unshadow::~Unshadow() = default; | ||||
| 
 | ||||
| void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
|     State(ctx).Run(); | ||||
| Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     return State(src).Run(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -29,16 +29,13 @@ class Unshadow final : public Castable<Unshadow, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~Unshadow() override; | ||||
| 
 | ||||
|   protected: | ||||
|     struct State; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -35,7 +35,51 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions); | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| class State { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     auto& sem = program->Sem(); | ||||
|     for (auto* f : program->AST().Functions()) { | ||||
|         if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct UnwindDiscardFunctions::State { | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the context
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|         ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { | ||||
|             // Iterate block statements and replace them as needed.
 | ||||
|             for (auto* stmt : block->statements) { | ||||
|                 if (auto* new_stmt = Statement(stmt)) { | ||||
|                     ctx.Replace(stmt, new_stmt); | ||||
|                 } | ||||
| 
 | ||||
|                 // Handle for loops, as they are the only other AST node that
 | ||||
|                 // contains statements outside of BlockStatements.
 | ||||
|                 if (auto* fl = stmt->As<ast::ForLoopStatement>()) { | ||||
|                     if (auto* new_stmt = Statement(fl->initializer)) { | ||||
|                         ctx.Replace(fl->initializer, new_stmt); | ||||
|                     } | ||||
|                     if (auto* new_stmt = Statement(fl->continuing)) { | ||||
|                         // NOTE: Should never reach here as we cannot discard in a
 | ||||
|                         // continuing block.
 | ||||
|                         ctx.Replace(fl->continuing, new_stmt); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             return nullptr; | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     CloneContext& ctx; | ||||
|     ProgramBuilder& b; | ||||
| @ -163,7 +207,7 @@ class State { | ||||
|     // Returns true if `stmt` is a for-loop initializer statement.
 | ||||
|     bool IsForLoopInitStatement(const ast::Statement* stmt) { | ||||
|         if (auto* sem_stmt = sem.Get(stmt)) { | ||||
|             if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) { | ||||
|             if (auto* sem_fl = tint::As<sem::ForLoopStatement>(sem_stmt->Parent())) { | ||||
|                 return sem_fl->Declaration()->initializer == stmt; | ||||
|             } | ||||
|         } | ||||
| @ -305,60 +349,26 @@ class State { | ||||
|                 return TryInsertAfter(s, sem_expr); | ||||
|             }); | ||||
|     } | ||||
| 
 | ||||
|   public: | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the context
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     void Run() { | ||||
|         ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { | ||||
|             // Iterate block statements and replace them as needed.
 | ||||
|             for (auto* stmt : block->statements) { | ||||
|                 if (auto* new_stmt = Statement(stmt)) { | ||||
|                     ctx.Replace(stmt, new_stmt); | ||||
|                 } | ||||
| 
 | ||||
|                 // Handle for loops, as they are the only other AST node that
 | ||||
|                 // contains statements outside of BlockStatements.
 | ||||
|                 if (auto* fl = stmt->As<ast::ForLoopStatement>()) { | ||||
|                     if (auto* new_stmt = Statement(fl->initializer)) { | ||||
|                         ctx.Replace(fl->initializer, new_stmt); | ||||
|                     } | ||||
|                     if (auto* new_stmt = Statement(fl->continuing)) { | ||||
|                         // NOTE: Should never reach here as we cannot discard in a
 | ||||
|                         // continuing block.
 | ||||
|                         ctx.Replace(fl->continuing, new_stmt); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             return nullptr; | ||||
|         }); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| UnwindDiscardFunctions::UnwindDiscardFunctions() = default; | ||||
| UnwindDiscardFunctions::~UnwindDiscardFunctions() = default; | ||||
| 
 | ||||
| void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src, | ||||
|                                                      const DataMap&, | ||||
|                                                      DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     State state(ctx); | ||||
|     state.Run(); | ||||
| } | ||||
| 
 | ||||
| bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const { | ||||
|     auto& sem = program->Sem(); | ||||
|     for (auto* f : program->AST().Functions()) { | ||||
|         if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -44,19 +44,13 @@ class UnwindDiscardFunctions final : public Castable<UnwindDiscardFunctions, Tra | ||||
|     /// Destructor
 | ||||
|     ~UnwindDiscardFunctions() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
|   private: | ||||
|     struct State; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -30,7 +30,59 @@ | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| /// Private implementation of HoistToDeclBefore transform
 | ||||
| class HoistToDeclBefore::State { | ||||
| struct HoistToDeclBefore::State { | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the clone context
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::Add()
 | ||||
|     bool Add(const sem::Expression* before_expr, | ||||
|              const ast::Expression* expr, | ||||
|              bool as_let, | ||||
|              const char* decl_name) { | ||||
|         auto name = b.Symbols().New(decl_name); | ||||
| 
 | ||||
|         if (as_let) { | ||||
|             auto builder = [this, expr, name] { | ||||
|                 return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr))); | ||||
|             }; | ||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||
|                 return false; | ||||
|             } | ||||
|         } else { | ||||
|             auto builder = [this, expr, name] { | ||||
|                 return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); | ||||
|             }; | ||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Replace the initializer expression with a reference to the let
 | ||||
|         ctx.Replace(expr, b.Expr(name)); | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
 | ||||
|     bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { | ||||
|         if (stmt) { | ||||
|             auto builder = [stmt] { return stmt; }; | ||||
|             return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||
|         } | ||||
|         return InsertBeforeImpl(before_stmt, Decompose{}); | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
 | ||||
|     bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { | ||||
|         return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::Prepare()
 | ||||
|     bool Prepare(const sem::Expression* before_expr) { | ||||
|         return InsertBefore(before_expr->Stmt(), nullptr); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     CloneContext& ctx; | ||||
|     ProgramBuilder& b; | ||||
| 
 | ||||
| @ -215,6 +267,8 @@ class HoistToDeclBefore::State { | ||||
| 
 | ||||
|     template <typename BUILDER> | ||||
|     bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) { | ||||
|         (void)builder;  // Avoid 'unused parameter' warning due to 'if constexpr'
 | ||||
| 
 | ||||
|         auto* ip = before_stmt->Declaration(); | ||||
| 
 | ||||
|         auto* else_if = before_stmt->As<sem::IfStatement>(); | ||||
| @ -299,58 +353,6 @@ class HoistToDeclBefore::State { | ||||
|             << "unhandled expression parent statement type: " << parent->TypeInfo().name; | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|   public: | ||||
|     /// Constructor
 | ||||
|     /// @param ctx_in the clone context
 | ||||
|     explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::Add()
 | ||||
|     bool Add(const sem::Expression* before_expr, | ||||
|              const ast::Expression* expr, | ||||
|              bool as_let, | ||||
|              const char* decl_name) { | ||||
|         auto name = b.Symbols().New(decl_name); | ||||
| 
 | ||||
|         if (as_let) { | ||||
|             auto builder = [this, expr, name] { | ||||
|                 return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr))); | ||||
|             }; | ||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||
|                 return false; | ||||
|             } | ||||
|         } else { | ||||
|             auto builder = [this, expr, name] { | ||||
|                 return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); | ||||
|             }; | ||||
|             if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Replace the initializer expression with a reference to the let
 | ||||
|         ctx.Replace(expr, b.Expr(name)); | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
 | ||||
|     bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { | ||||
|         if (stmt) { | ||||
|             auto builder = [stmt] { return stmt; }; | ||||
|             return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||
|         } | ||||
|         return InsertBeforeImpl(before_stmt, Decompose{}); | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
 | ||||
|     bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { | ||||
|         return InsertBeforeImpl(before_stmt, std::move(builder)); | ||||
|     } | ||||
| 
 | ||||
|     /// @copydoc HoistToDeclBefore::Prepare()
 | ||||
|     bool Prepare(const sem::Expression* before_expr) { | ||||
|         return InsertBefore(before_expr->Stmt(), nullptr); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {} | ||||
|  | ||||
| @ -77,7 +77,7 @@ class HoistToDeclBefore { | ||||
|     bool Prepare(const sem::Expression* before_expr); | ||||
| 
 | ||||
|   private: | ||||
|     class State; | ||||
|     struct State; | ||||
|     std::unique_ptr<State> state_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -13,6 +13,9 @@ | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include "src/tint/transform/var_for_dynamic_index.h" | ||||
| 
 | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/program_builder.h" | ||||
| #include "src/tint/transform/utils/hoist_to_decl_before.h" | ||||
| 
 | ||||
| @ -22,7 +25,12 @@ VarForDynamicIndex::VarForDynamicIndex() = default; | ||||
| 
 | ||||
| VarForDynamicIndex::~VarForDynamicIndex() = default; | ||||
| 
 | ||||
| void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src, | ||||
|                                                  const DataMap&, | ||||
|                                                  DataMap&) const { | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     HoistToDeclBefore hoist_to_decl_before(ctx); | ||||
| 
 | ||||
|     // Extracts array and matrix values that are dynamically indexed to a
 | ||||
| @ -30,7 +38,7 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const | ||||
|     auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) { | ||||
|         auto* index_expr = access_expr->index; | ||||
|         auto* object_expr = access_expr->object; | ||||
|         auto& sem = ctx.src->Sem(); | ||||
|         auto& sem = src->Sem(); | ||||
| 
 | ||||
|         if (sem.Get(index_expr)->ConstantValue()) { | ||||
|             // Index expression resolves to a compile time value.
 | ||||
| @ -49,15 +57,21 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const | ||||
|         return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index"); | ||||
|     }; | ||||
| 
 | ||||
|     for (auto* node : ctx.src->ASTNodes().Objects()) { | ||||
|     bool index_accessor_found = false; | ||||
|     for (auto* node : src->ASTNodes().Objects()) { | ||||
|         if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) { | ||||
|             if (!dynamic_index_to_var(access_expr)) { | ||||
|                 return; | ||||
|                 return Program(std::move(b)); | ||||
|             } | ||||
|             index_accessor_found = true; | ||||
|         } | ||||
|     } | ||||
|     if (!index_accessor_found) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -31,14 +31,10 @@ class VarForDynamicIndex : public Transform { | ||||
|     /// Destructor
 | ||||
|     ~VarForDynamicIndex() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -30,11 +30,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| 
 | ||||
| VectorizeMatrixConversions::VectorizeMatrixConversions() = default; | ||||
| namespace { | ||||
| 
 | ||||
| VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; | ||||
| 
 | ||||
| bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* sem = program->Sem().Get<sem::Expression>(node)) { | ||||
|             if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) { | ||||
| @ -50,14 +48,29 @@ bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| VectorizeMatrixConversions::VectorizeMatrixConversions() = default; | ||||
| 
 | ||||
| VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; | ||||
| 
 | ||||
| Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src, | ||||
|                                                          const DataMap&, | ||||
|                                                          DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     using HelperFunctionKey = | ||||
|         utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>; | ||||
| 
 | ||||
|     std::unordered_map<HelperFunctionKey, Symbol> matrix_convs; | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { | ||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* ty_conv = call->Target()->As<sem::TypeConversion>(); | ||||
|         if (!ty_conv) { | ||||
|             return nullptr; | ||||
| @ -72,16 +85,16 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         auto& src = args[0]; | ||||
|         auto& matrix = args[0]; | ||||
| 
 | ||||
|         auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>(); | ||||
|         auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>(); | ||||
|         if (!src_type) { | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         // The source and destination type of a matrix conversion must have a same shape.
 | ||||
|         if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) { | ||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|             TINT_ICE(Transform, b.Diagnostics()) | ||||
|                 << "source and destination matrix has different shape in matrix conversion"; | ||||
|             return nullptr; | ||||
|         } | ||||
| @ -90,47 +103,45 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& | ||||
|             utils::Vector<const ast::Expression*, 4> columns; | ||||
|             for (uint32_t c = 0; c < dst_type->columns(); c++) { | ||||
|                 auto* src_matrix_expr = src_expression_builder(); | ||||
|                 auto* src_column_expr = | ||||
|                     ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c))); | ||||
|                 columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), | ||||
|                                                 src_column_expr)); | ||||
|                 auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c))); | ||||
|                 columns.Push( | ||||
|                     b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr)); | ||||
|             } | ||||
|             return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns); | ||||
|             return b.Construct(CreateASTTypeFor(ctx, dst_type), columns); | ||||
|         }; | ||||
| 
 | ||||
|         // Replace the matrix conversion to column vector conversions and a matrix construction.
 | ||||
|         if (!src->HasSideEffects()) { | ||||
|         if (!matrix->HasSideEffects()) { | ||||
|             // Simply use the argument's declaration if it has no side effects.
 | ||||
|             return build_vectorized_conversion_expression([&]() {  //
 | ||||
|                 return ctx.Clone(src->Declaration()); | ||||
|                 return ctx.Clone(matrix->Declaration()); | ||||
|             }); | ||||
|         } else { | ||||
|             // If has side effects, use a helper function.
 | ||||
|             auto fn = | ||||
|                 utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] { | ||||
|                     auto name = | ||||
|                         ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) + | ||||
|                                                "x" + std::to_string(src_type->rows()) + "_" + | ||||
|                                                ctx.dst->FriendlyName(src_type->type()) + "_" + | ||||
|                                                ctx.dst->FriendlyName(dst_type->type())); | ||||
|                     ctx.dst->Func( | ||||
|                         name, | ||||
|                         utils::Vector{ | ||||
|                             ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)), | ||||
|                         }, | ||||
|                         CreateASTTypeFor(ctx, dst_type), | ||||
|                         utils::Vector{ | ||||
|                             ctx.dst->Return(build_vectorized_conversion_expression([&]() {  //
 | ||||
|                                 return ctx.dst->Expr("value"); | ||||
|                             })), | ||||
|                         }); | ||||
|                     auto name = b.Symbols().New( | ||||
|                         "convert_mat" + std::to_string(src_type->columns()) + "x" + | ||||
|                         std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) + | ||||
|                         "_" + b.FriendlyName(dst_type->type())); | ||||
|                     b.Func(name, | ||||
|                            utils::Vector{ | ||||
|                                b.Param("value", CreateASTTypeFor(ctx, src_type)), | ||||
|                            }, | ||||
|                            CreateASTTypeFor(ctx, dst_type), | ||||
|                            utils::Vector{ | ||||
|                                b.Return(build_vectorized_conversion_expression([&]() {  //
 | ||||
|                                    return b.Expr("value"); | ||||
|                                })), | ||||
|                            }); | ||||
|                     return name; | ||||
|                 }); | ||||
|             return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); | ||||
|             return b.Call(fn, ctx.Clone(args[0]->Declaration())); | ||||
|         } | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -28,19 +28,10 @@ class VectorizeMatrixConversions final : public Castable<VectorizeMatrixConversi | ||||
|     /// Destructor
 | ||||
|     ~VectorizeMatrixConversions() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -27,12 +27,9 @@ | ||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default; | ||||
| 
 | ||||
| VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default; | ||||
| 
 | ||||
| bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (auto* call = program->Sem().Get<sem::Call>(node)) { | ||||
|             if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) { | ||||
| @ -46,11 +43,26 @@ bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default; | ||||
| 
 | ||||
| VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default; | ||||
| 
 | ||||
| Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src, | ||||
|                                                                 const DataMap&, | ||||
|                                                                 DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     std::unordered_map<const sem::Matrix*, Symbol> scalar_inits; | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { | ||||
|         auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); | ||||
|         auto* ty_init = call->Target()->As<sem::TypeInitializer>(); | ||||
|         if (!ty_init) { | ||||
|             return nullptr; | ||||
| @ -87,10 +99,10 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | ||||
|                 } | ||||
| 
 | ||||
|                 // Construct the column vector.
 | ||||
|                 columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), | ||||
|                                           std::move(row_values))); | ||||
|                 columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), | ||||
|                                    std::move(row_values))); | ||||
|             } | ||||
|             return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns); | ||||
|             return b.Construct(CreateASTTypeFor(ctx, mat_type), columns); | ||||
|         }; | ||||
| 
 | ||||
|         if (args.Length() == 1) { | ||||
| @ -98,23 +110,22 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | ||||
|             // This is done to ensure that the single argument value is only evaluated once, and
 | ||||
|             // with the correct expression evaluation order.
 | ||||
|             auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] { | ||||
|                 auto name = | ||||
|                     ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" + | ||||
|                                            std::to_string(mat_type->rows())); | ||||
|                 ctx.dst->Func(name, | ||||
|                               utils::Vector{ | ||||
|                                   // Single scalar parameter
 | ||||
|                                   ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())), | ||||
|                               }, | ||||
|                               CreateASTTypeFor(ctx, mat_type), | ||||
|                               utils::Vector{ | ||||
|                                   ctx.dst->Return(build_mat([&](uint32_t, uint32_t) {  //
 | ||||
|                                       return ctx.dst->Expr("value"); | ||||
|                                   })), | ||||
|                               }); | ||||
|                 auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) + | ||||
|                                             "x" + std::to_string(mat_type->rows())); | ||||
|                 b.Func(name, | ||||
|                        utils::Vector{ | ||||
|                            // Single scalar parameter
 | ||||
|                            b.Param("value", CreateASTTypeFor(ctx, mat_type->type())), | ||||
|                        }, | ||||
|                        CreateASTTypeFor(ctx, mat_type), | ||||
|                        utils::Vector{ | ||||
|                            b.Return(build_mat([&](uint32_t, uint32_t) {  //
 | ||||
|                                return b.Expr("value"); | ||||
|                            })), | ||||
|                        }); | ||||
|                 return name; | ||||
|             }); | ||||
|             return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); | ||||
|             return b.Call(fn, ctx.Clone(args[0]->Declaration())); | ||||
|         } | ||||
| 
 | ||||
|         if (args.Length() == mat_type->columns() * mat_type->rows()) { | ||||
| @ -123,12 +134,13 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         TINT_ICE(Transform, ctx.dst->Diagnostics()) | ||||
|         TINT_ICE(Transform, b.Diagnostics()) | ||||
|             << "matrix initializer has unexpected number of arguments"; | ||||
|         return nullptr; | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -29,19 +29,10 @@ class VectorizeScalarMatrixInitializers final | ||||
|     /// Destructor
 | ||||
|     ~VectorizeScalarMatrixInitializers() override; | ||||
| 
 | ||||
|     /// @param program the program to inspect
 | ||||
|     /// @param data optional extra transform-specific input data
 | ||||
|     /// @returns true if this transform should be run for the given program
 | ||||
|     bool ShouldRun(const Program* program, const DataMap& data = {}) const override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
| @ -201,13 +201,46 @@ DataType DataTypeOf(VertexFormat format) { | ||||
|     return {BaseType::kInvalid, 0}; | ||||
| } | ||||
| 
 | ||||
| struct State { | ||||
|     State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {} | ||||
|     State(const State&) = default; | ||||
|     ~State() = default; | ||||
| }  // namespace
 | ||||
| 
 | ||||
|     /// LocationReplacement describes an ast::Variable replacement for a
 | ||||
|     /// location input.
 | ||||
| /// PIMPL state for the transform
 | ||||
| struct VertexPulling::State { | ||||
|     /// Constructor
 | ||||
|     /// @param program the source program
 | ||||
|     /// @param c the VertexPulling config
 | ||||
|     State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {} | ||||
| 
 | ||||
|     /// Runs the transform
 | ||||
|     /// @returns the new program or SkipTransform if the transform is not required
 | ||||
|     ApplyResult Run() { | ||||
|         // Find entry point
 | ||||
|         const ast::Function* func = nullptr; | ||||
|         for (auto* fn : src->AST().Functions()) { | ||||
|             if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||
|                 if (func != nullptr) { | ||||
|                     b.Diagnostics().add_error( | ||||
|                         diag::System::Transform, | ||||
|                         "VertexPulling found more than one vertex entry point"); | ||||
|                     return Program(std::move(b)); | ||||
|                 } | ||||
|                 func = fn; | ||||
|             } | ||||
|         } | ||||
|         if (func == nullptr) { | ||||
|             b.Diagnostics().add_error(diag::System::Transform, | ||||
|                                       "Vertex stage entry point not found"); | ||||
|             return Program(std::move(b)); | ||||
|         } | ||||
| 
 | ||||
|         AddVertexStorageBuffers(); | ||||
|         Process(func); | ||||
| 
 | ||||
|         ctx.Clone(); | ||||
|         return Program(std::move(b)); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
|     /// LocationReplacement describes an ast::Variable replacement for a location input.
 | ||||
|     struct LocationReplacement { | ||||
|         /// The variable to replace in the source Program
 | ||||
|         ast::Variable* from; | ||||
| @ -215,13 +248,22 @@ struct State { | ||||
|         ast::Variable* to; | ||||
|     }; | ||||
| 
 | ||||
|     /// LocationInfo describes an input location
 | ||||
|     struct LocationInfo { | ||||
|         /// A builder that builds the expression that resolves to the (transformed) input location
 | ||||
|         std::function<const ast::Expression*()> expr; | ||||
|         /// The store type of the location variable
 | ||||
|         const sem::Type* type; | ||||
|     }; | ||||
| 
 | ||||
|     CloneContext& ctx; | ||||
|     /// The source program
 | ||||
|     const Program* const src; | ||||
|     /// The transform config
 | ||||
|     VertexPulling::Config const cfg; | ||||
|     /// The target program builder
 | ||||
|     ProgramBuilder b; | ||||
|     /// The clone context
 | ||||
|     CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; | ||||
|     std::unordered_map<uint32_t, LocationInfo> location_info; | ||||
|     std::function<const ast::Expression*()> vertex_index_expr = nullptr; | ||||
|     std::function<const ast::Expression*()> instance_index_expr = nullptr; | ||||
| @ -235,7 +277,7 @@ struct State { | ||||
|     Symbol GetVertexBufferName(uint32_t index) { | ||||
|         return utils::GetOrCreate(vertex_buffer_names, index, [&] { | ||||
|             static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; | ||||
|             return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); | ||||
|             return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
| @ -243,7 +285,7 @@ struct State { | ||||
|     Symbol GetStructBufferName() { | ||||
|         if (!struct_buffer_name.IsValid()) { | ||||
|             static const char kStructBufferName[] = "tint_vertex_data"; | ||||
|             struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName); | ||||
|             struct_buffer_name = b.Symbols().New(kStructBufferName); | ||||
|         } | ||||
|         return struct_buffer_name; | ||||
|     } | ||||
| @ -252,21 +294,19 @@ struct State { | ||||
|     void AddVertexStorageBuffers() { | ||||
|         // Creating the struct type
 | ||||
|         static const char kStructName[] = "TintVertexData"; | ||||
|         auto* struct_type = | ||||
|             ctx.dst->Structure(ctx.dst->Symbols().New(kStructName), | ||||
|                                utils::Vector{ | ||||
|                                    ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<u32>()), | ||||
|                                }); | ||||
|         auto* struct_type = b.Structure(b.Symbols().New(kStructName), | ||||
|                                         utils::Vector{ | ||||
|                                             b.Member(GetStructBufferName(), b.ty.array<u32>()), | ||||
|                                         }); | ||||
|         for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { | ||||
|             // The decorated variable with struct type
 | ||||
|             ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type), | ||||
|                                ast::AddressSpace::kStorage, ast::Access::kRead, | ||||
|                                ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group))); | ||||
|             b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage, | ||||
|                         ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group))); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Creates and returns the assignment to the variables from the buffers
 | ||||
|     ast::BlockStatement* CreateVertexPullingPreamble() { | ||||
|     const ast::BlockStatement* CreateVertexPullingPreamble() { | ||||
|         // Assign by looking at the vertex descriptor to find attributes with
 | ||||
|         // matching location.
 | ||||
| 
 | ||||
| @ -276,7 +316,7 @@ struct State { | ||||
|             const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; | ||||
| 
 | ||||
|             if ((buffer_layout.array_stride & 3) != 0) { | ||||
|                 ctx.dst->Diagnostics().add_error( | ||||
|                 b.Diagnostics().add_error( | ||||
|                     diag::System::Transform, | ||||
|                     "WebGPU requires that vertex stride must be a multiple of 4 bytes, " | ||||
|                     "but VertexPulling array stride for buffer " + | ||||
| @ -292,15 +332,15 @@ struct State { | ||||
|             // buffer_array_base is the base array offset for all the vertex
 | ||||
|             // attributes. These are units of uint (4 bytes).
 | ||||
|             auto buffer_array_base = | ||||
|                 ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); | ||||
|                 b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); | ||||
| 
 | ||||
|             auto* attribute_offset = index_expr; | ||||
|             if (buffer_layout.array_stride != 4) { | ||||
|                 attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u)); | ||||
|                 attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u)); | ||||
|             } | ||||
| 
 | ||||
|             // let pulling_offset_n = <attribute_offset>
 | ||||
|             stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset))); | ||||
|             stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset))); | ||||
| 
 | ||||
|             for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { | ||||
|                 auto it = location_info.find(attribute_desc.shader_location); | ||||
| @ -320,8 +360,8 @@ struct State { | ||||
|                     err << "VertexAttributeDescriptor for location " | ||||
|                         << std::to_string(attribute_desc.shader_location) << " has format " | ||||
|                         << attribute_desc.format << " but shader expects " | ||||
|                         << var.type->FriendlyName(ctx.src->Symbols()); | ||||
|                     ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str()); | ||||
|                         << var.type->FriendlyName(src->Symbols()); | ||||
|                     b.Diagnostics().add_error(diag::System::Transform, err.str()); | ||||
|                     return nullptr; | ||||
|                 } | ||||
| 
 | ||||
| @ -337,16 +377,16 @@ struct State { | ||||
|                     // WGSL variable vector width is smaller than the loaded vector width
 | ||||
|                     switch (var_dt.width) { | ||||
|                         case 1: | ||||
|                             value = ctx.dst->MemberAccessor(fetch, "x"); | ||||
|                             value = b.MemberAccessor(fetch, "x"); | ||||
|                             break; | ||||
|                         case 2: | ||||
|                             value = ctx.dst->MemberAccessor(fetch, "xy"); | ||||
|                             value = b.MemberAccessor(fetch, "xy"); | ||||
|                             break; | ||||
|                         case 3: | ||||
|                             value = ctx.dst->MemberAccessor(fetch, "xyz"); | ||||
|                             value = b.MemberAccessor(fetch, "xyz"); | ||||
|                             break; | ||||
|                         default: | ||||
|                             TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width; | ||||
|                             TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width; | ||||
|                             return nullptr; | ||||
|                     } | ||||
|                 } else if (var_dt.width > fmt_dt.width) { | ||||
| @ -355,32 +395,32 @@ struct State { | ||||
|                     utils::Vector<const ast::Expression*, 8> values{fetch}; | ||||
|                     switch (var_dt.base_type) { | ||||
|                         case BaseType::kI32: | ||||
|                             ty = ctx.dst->ty.i32(); | ||||
|                             ty = b.ty.i32(); | ||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i)); | ||||
|                                 values.Push(b.Expr((i == 3) ? 1_i : 0_i)); | ||||
|                             } | ||||
|                             break; | ||||
|                         case BaseType::kU32: | ||||
|                             ty = ctx.dst->ty.u32(); | ||||
|                             ty = b.ty.u32(); | ||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u)); | ||||
|                                 values.Push(b.Expr((i == 3) ? 1_u : 0_u)); | ||||
|                             } | ||||
|                             break; | ||||
|                         case BaseType::kF32: | ||||
|                             ty = ctx.dst->ty.f32(); | ||||
|                             ty = b.ty.f32(); | ||||
|                             for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { | ||||
|                                 values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f)); | ||||
|                                 values.Push(b.Expr((i == 3) ? 1_f : 0_f)); | ||||
|                             } | ||||
|                             break; | ||||
|                         default: | ||||
|                             TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type; | ||||
|                             TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type; | ||||
|                             return nullptr; | ||||
|                     } | ||||
|                     value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values); | ||||
|                     value = b.Construct(b.ty.vec(ty, var_dt.width), values); | ||||
|                 } | ||||
| 
 | ||||
|                 // Assign the value to the WGSL variable
 | ||||
|                 stmts.Push(ctx.dst->Assign(var.expr(), value)); | ||||
|                 stmts.Push(b.Assign(var.expr(), value)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| @ -388,7 +428,7 @@ struct State { | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         return ctx.dst->create<ast::BlockStatement>(std::move(stmts)); | ||||
|         return b.Block(std::move(stmts)); | ||||
|     } | ||||
| 
 | ||||
|     /// Generates an expression reading from a buffer a specific format.
 | ||||
| @ -407,7 +447,7 @@ struct State { | ||||
|         }; | ||||
| 
 | ||||
|         // Returns a i32 loaded from buffer_base + offset.
 | ||||
|         auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); }; | ||||
|         auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); }; | ||||
| 
 | ||||
|         // Returns a u32 loaded from buffer_base + offset + 4.
 | ||||
|         auto load_next_u32 = [&] { | ||||
| @ -415,7 +455,7 @@ struct State { | ||||
|         }; | ||||
| 
 | ||||
|         // Returns a i32 loaded from buffer_base + offset + 4.
 | ||||
|         auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); }; | ||||
|         auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); }; | ||||
| 
 | ||||
|         // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
 | ||||
|         // The low 16 bits are 0.
 | ||||
| @ -427,17 +467,17 @@ struct State { | ||||
|                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); | ||||
|             switch (offset & 3) { | ||||
|                 case 0: | ||||
|                     return ctx.dst->Shl(low_u32, 16_u); | ||||
|                     return b.Shl(low_u32, 16_u); | ||||
|                 case 1: | ||||
|                     return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u); | ||||
|                     return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u); | ||||
|                 case 2: | ||||
|                     return ctx.dst->And(low_u32, 0xffff0000_u); | ||||
|                     return b.And(low_u32, 0xffff0000_u); | ||||
|                 default: {  // 3:
 | ||||
|                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, | ||||
|                                                    VertexFormat::kUint32); | ||||
|                     auto* shr = ctx.dst->Shr(low_u32, 8_u); | ||||
|                     auto* shl = ctx.dst->Shl(high_u32, 24_u); | ||||
|                     return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u); | ||||
|                     auto* shr = b.Shr(low_u32, 8_u); | ||||
|                     auto* shl = b.Shl(high_u32, 24_u); | ||||
|                     return b.And(b.Or(shl, shr), 0xffff0000_u); | ||||
|                 } | ||||
|             } | ||||
|         }; | ||||
| @ -450,24 +490,24 @@ struct State { | ||||
|                 LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); | ||||
|             switch (offset & 3) { | ||||
|                 case 0: | ||||
|                     return ctx.dst->And(low_u32, 0xffff_u); | ||||
|                     return b.And(low_u32, 0xffff_u); | ||||
|                 case 1: | ||||
|                     return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u); | ||||
|                     return b.And(b.Shr(low_u32, 8_u), 0xffff_u); | ||||
|                 case 2: | ||||
|                     return ctx.dst->Shr(low_u32, 16_u); | ||||
|                     return b.Shr(low_u32, 16_u); | ||||
|                 default: {  // 3:
 | ||||
|                     auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, | ||||
|                                                    VertexFormat::kUint32); | ||||
|                     auto* shr = ctx.dst->Shr(low_u32, 24_u); | ||||
|                     auto* shl = ctx.dst->Shl(high_u32, 8_u); | ||||
|                     return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u); | ||||
|                     auto* shr = b.Shr(low_u32, 24_u); | ||||
|                     auto* shl = b.Shl(high_u32, 8_u); | ||||
|                     return b.And(b.Or(shl, shr), 0xffff_u); | ||||
|                 } | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
 | ||||
|         // The low 16 bits are 0.
 | ||||
|         auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); }; | ||||
|         auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); }; | ||||
| 
 | ||||
|         // Assumptions are made that alignment must be at least as large as the size
 | ||||
|         // of a single component.
 | ||||
| @ -480,128 +520,121 @@ struct State { | ||||
| 
 | ||||
|                 // Vectors of basic primitives
 | ||||
|             case VertexFormat::kUint32x2: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), | ||||
|                                VertexFormat::kUint32, 2); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2); | ||||
|             case VertexFormat::kUint32x3: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), | ||||
|                                VertexFormat::kUint32, 3); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3); | ||||
|             case VertexFormat::kUint32x4: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), | ||||
|                                VertexFormat::kUint32, 4); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4); | ||||
|             case VertexFormat::kSint32x2: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), | ||||
|                                VertexFormat::kSint32, 2); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2); | ||||
|             case VertexFormat::kSint32x3: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), | ||||
|                                VertexFormat::kSint32, 3); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3); | ||||
|             case VertexFormat::kSint32x4: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), | ||||
|                                VertexFormat::kSint32, 4); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4); | ||||
|             case VertexFormat::kFloat32x2: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), | ||||
|                                VertexFormat::kFloat32, 2); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||
|                                2); | ||||
|             case VertexFormat::kFloat32x3: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), | ||||
|                                VertexFormat::kFloat32, 3); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||
|                                3); | ||||
|             case VertexFormat::kFloat32x4: | ||||
|                 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), | ||||
|                                VertexFormat::kFloat32, 4); | ||||
|                 return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, | ||||
|                                4); | ||||
| 
 | ||||
|             case VertexFormat::kUint8x2: { | ||||
|                 // yyxx0000, yyxx0000
 | ||||
|                 auto* u16s = ctx.dst->vec2<u32>(load_u16_h()); | ||||
|                 auto* u16s = b.vec2<u32>(load_u16_h()); | ||||
|                 // xx000000, yyxx0000
 | ||||
|                 auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8_u, 0_u)); | ||||
|                 auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u)); | ||||
|                 // 000000xx, 000000yy
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u)); | ||||
|                 return b.Shr(shl, b.vec2<u32>(24_u)); | ||||
|             } | ||||
|             case VertexFormat::kUint8x4: { | ||||
|                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 | ||||
|                 auto* u32s = ctx.dst->vec4<u32>(load_u32()); | ||||
|                 auto* u32s = b.vec4<u32>(load_u32()); | ||||
|                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 | ||||
|                 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||
|                 auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||
|                 // 000000xx, 000000yy, 000000zz, 000000ww
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u)); | ||||
|                 return b.Shr(shl, b.vec4<u32>(24_u)); | ||||
|             } | ||||
|             case VertexFormat::kUint16x2: { | ||||
|                 // yyyyxxxx, yyyyxxxx
 | ||||
|                 auto* u32s = ctx.dst->vec2<u32>(load_u32()); | ||||
|                 auto* u32s = b.vec2<u32>(load_u32()); | ||||
|                 // xxxx0000, yyyyxxxx
 | ||||
|                 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16_u, 0_u)); | ||||
|                 auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u)); | ||||
|                 // 0000xxxx, 0000yyyy
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u)); | ||||
|                 return b.Shr(shl, b.vec2<u32>(16_u)); | ||||
|             } | ||||
|             case VertexFormat::kUint16x4: { | ||||
|                 // yyyyxxxx, wwwwzzzz
 | ||||
|                 auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32()); | ||||
|                 auto* u32s = b.vec2<u32>(load_u32(), load_next_u32()); | ||||
|                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 | ||||
|                 auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy"); | ||||
|                 auto* xxyy = b.MemberAccessor(u32s, "xxyy"); | ||||
|                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 | ||||
|                 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||
|                 auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||
|                 // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u)); | ||||
|                 return b.Shr(shl, b.vec4<u32>(16_u)); | ||||
|             } | ||||
|             case VertexFormat::kSint8x2: { | ||||
|                 // yyxx0000, yyxx0000
 | ||||
|                 auto* i16s = ctx.dst->vec2<i32>(load_i16_h()); | ||||
|                 auto* i16s = b.vec2<i32>(load_i16_h()); | ||||
|                 // xx000000, yyxx0000
 | ||||
|                 auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8_u, 0_u)); | ||||
|                 auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u)); | ||||
|                 // ssssssxx, ssssssyy
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u)); | ||||
|                 return b.Shr(shl, b.vec2<u32>(24_u)); | ||||
|             } | ||||
|             case VertexFormat::kSint8x4: { | ||||
|                 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
 | ||||
|                 auto* i32s = ctx.dst->vec4<i32>(load_i32()); | ||||
|                 auto* i32s = b.vec4<i32>(load_i32()); | ||||
|                 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
 | ||||
|                 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||
|                 auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u)); | ||||
|                 // ssssssxx, ssssssyy, sssssszz, ssssssww
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u)); | ||||
|                 return b.Shr(shl, b.vec4<u32>(24_u)); | ||||
|             } | ||||
|             case VertexFormat::kSint16x2: { | ||||
|                 // yyyyxxxx, yyyyxxxx
 | ||||
|                 auto* i32s = ctx.dst->vec2<i32>(load_i32()); | ||||
|                 auto* i32s = b.vec2<i32>(load_i32()); | ||||
|                 // xxxx0000, yyyyxxxx
 | ||||
|                 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16_u, 0_u)); | ||||
|                 auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u)); | ||||
|                 // ssssxxxx, ssssyyyy
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u)); | ||||
|                 return b.Shr(shl, b.vec2<u32>(16_u)); | ||||
|             } | ||||
|             case VertexFormat::kSint16x4: { | ||||
|                 // yyyyxxxx, wwwwzzzz
 | ||||
|                 auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32()); | ||||
|                 auto* i32s = b.vec2<i32>(load_i32(), load_next_i32()); | ||||
|                 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
 | ||||
|                 auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy"); | ||||
|                 auto* xxyy = b.MemberAccessor(i32s, "xxyy"); | ||||
|                 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
 | ||||
|                 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||
|                 auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u)); | ||||
|                 // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
 | ||||
|                 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u)); | ||||
|                 return b.Shr(shl, b.vec4<u32>(16_u)); | ||||
|             } | ||||
|             case VertexFormat::kUnorm8x2: | ||||
|                 return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy"); | ||||
|                 return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy"); | ||||
|             case VertexFormat::kSnorm8x2: | ||||
|                 return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy"); | ||||
|                 return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy"); | ||||
|             case VertexFormat::kUnorm8x4: | ||||
|                 return ctx.dst->Call("unpack4x8unorm", load_u32()); | ||||
|                 return b.Call("unpack4x8unorm", load_u32()); | ||||
|             case VertexFormat::kSnorm8x4: | ||||
|                 return ctx.dst->Call("unpack4x8snorm", load_u32()); | ||||
|                 return b.Call("unpack4x8snorm", load_u32()); | ||||
|             case VertexFormat::kUnorm16x2: | ||||
|                 return ctx.dst->Call("unpack2x16unorm", load_u32()); | ||||
|                 return b.Call("unpack2x16unorm", load_u32()); | ||||
|             case VertexFormat::kSnorm16x2: | ||||
|                 return ctx.dst->Call("unpack2x16snorm", load_u32()); | ||||
|                 return b.Call("unpack2x16snorm", load_u32()); | ||||
|             case VertexFormat::kFloat16x2: | ||||
|                 return ctx.dst->Call("unpack2x16float", load_u32()); | ||||
|                 return b.Call("unpack2x16float", load_u32()); | ||||
|             case VertexFormat::kUnorm16x4: | ||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()), | ||||
|                                           ctx.dst->Call("unpack2x16unorm", load_next_u32())); | ||||
|                 return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()), | ||||
|                                    b.Call("unpack2x16unorm", load_next_u32())); | ||||
|             case VertexFormat::kSnorm16x4: | ||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()), | ||||
|                                           ctx.dst->Call("unpack2x16snorm", load_next_u32())); | ||||
|                 return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()), | ||||
|                                    b.Call("unpack2x16snorm", load_next_u32())); | ||||
|             case VertexFormat::kFloat16x4: | ||||
|                 return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()), | ||||
|                                           ctx.dst->Call("unpack2x16float", load_next_u32())); | ||||
|                 return b.vec4<f32>(b.Call("unpack2x16float", load_u32()), | ||||
|                                    b.Call("unpack2x16float", load_next_u32())); | ||||
|         } | ||||
| 
 | ||||
|         TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) | ||||
|             << "format " << static_cast<int>(format); | ||||
|         TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format); | ||||
|         return nullptr; | ||||
|     } | ||||
| 
 | ||||
| @ -623,12 +656,12 @@ struct State { | ||||
| 
 | ||||
|             const ast ::Expression* index = nullptr; | ||||
|             if (offset > 0) { | ||||
|                 index = ctx.dst->Add(array_base, u32(offset / 4)); | ||||
|                 index = b.Add(array_base, u32(offset / 4)); | ||||
|             } else { | ||||
|                 index = ctx.dst->Expr(array_base); | ||||
|                 index = b.Expr(array_base); | ||||
|             } | ||||
|             u = ctx.dst->IndexAccessor( | ||||
|                 ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); | ||||
|             u = b.IndexAccessor( | ||||
|                 b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); | ||||
| 
 | ||||
|         } else { | ||||
|             // Unaligned load
 | ||||
| @ -639,22 +672,22 @@ struct State { | ||||
| 
 | ||||
|             uint32_t shift = 8u * (offset & 3u); | ||||
| 
 | ||||
|             auto* low_shr = ctx.dst->Shr(low, u32(shift)); | ||||
|             auto* high_shl = ctx.dst->Shl(high, u32(32u - shift)); | ||||
|             u = ctx.dst->Or(low_shr, high_shl); | ||||
|             auto* low_shr = b.Shr(low, u32(shift)); | ||||
|             auto* high_shl = b.Shl(high, u32(32u - shift)); | ||||
|             u = b.Or(low_shr, high_shl); | ||||
|         } | ||||
| 
 | ||||
|         switch (format) { | ||||
|             case VertexFormat::kUint32: | ||||
|                 return u; | ||||
|             case VertexFormat::kSint32: | ||||
|                 return ctx.dst->Bitcast(ctx.dst->ty.i32(), u); | ||||
|                 return b.Bitcast(b.ty.i32(), u); | ||||
|             case VertexFormat::kFloat32: | ||||
|                 return ctx.dst->Bitcast(ctx.dst->ty.f32(), u); | ||||
|                 return b.Bitcast(b.ty.f32(), u); | ||||
|             default: | ||||
|                 break; | ||||
|         } | ||||
|         TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) | ||||
|         TINT_UNREACHABLE(Transform, b.Diagnostics()) | ||||
|             << "invalid format for LoadPrimitive" << static_cast<int>(format); | ||||
|         return nullptr; | ||||
|     } | ||||
| @ -682,8 +715,7 @@ struct State { | ||||
|             expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format)); | ||||
|         } | ||||
| 
 | ||||
|         return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count), | ||||
|                                   std::move(expr_list)); | ||||
|         return b.Construct(b.create<ast::Vector>(base_type, count), std::move(expr_list)); | ||||
|     } | ||||
| 
 | ||||
|     /// Process a non-struct entry point parameter.
 | ||||
| @ -696,34 +728,30 @@ struct State { | ||||
|             // Create a function-scope variable to replace the parameter.
 | ||||
|             auto func_var_sym = ctx.Clone(param->symbol); | ||||
|             auto* func_var_type = ctx.Clone(param->type); | ||||
|             auto* func_var = ctx.dst->Var(func_var_sym, func_var_type); | ||||
|             ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); | ||||
|             auto* func_var = b.Var(func_var_sym, func_var_type); | ||||
|             ctx.InsertFront(func->body->statements, b.Decl(func_var)); | ||||
|             // Capture mapping from location to the new variable.
 | ||||
|             LocationInfo info; | ||||
|             info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; | ||||
|             info.expr = [this, func_var]() { return b.Expr(func_var); }; | ||||
| 
 | ||||
|             auto* sem = ctx.src->Sem().Get<sem::Parameter>(param); | ||||
|             auto* sem = src->Sem().Get<sem::Parameter>(param); | ||||
|             info.type = sem->Type(); | ||||
| 
 | ||||
|             if (!sem->Location().has_value()) { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value"; | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) << "Location missing value"; | ||||
|                 return; | ||||
|             } | ||||
|             location_info[sem->Location().value()] = info; | ||||
|         } else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { | ||||
|             // Check for existing vertex_index and instance_index builtins.
 | ||||
|             if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { | ||||
|                 vertex_index_expr = [this, param]() { | ||||
|                     return ctx.dst->Expr(ctx.Clone(param->symbol)); | ||||
|                 }; | ||||
|                 vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; | ||||
|             } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { | ||||
|                 instance_index_expr = [this, param]() { | ||||
|                     return ctx.dst->Expr(ctx.Clone(param->symbol)); | ||||
|                 }; | ||||
|                 instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; | ||||
|             } | ||||
|             new_function_parameters.Push(ctx.Clone(param)); | ||||
|         } else { | ||||
|             TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; | ||||
|             TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| @ -746,7 +774,7 @@ struct State { | ||||
|         for (auto* member : struct_ty->members) { | ||||
|             auto member_sym = ctx.Clone(member->symbol); | ||||
|             std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() { | ||||
|                 return ctx.dst->MemberAccessor(param_sym, member_sym); | ||||
|                 return b.MemberAccessor(param_sym, member_sym); | ||||
|             }; | ||||
| 
 | ||||
|             if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) { | ||||
| @ -754,7 +782,7 @@ struct State { | ||||
|                 LocationInfo info; | ||||
|                 info.expr = member_expr; | ||||
| 
 | ||||
|                 auto* sem = ctx.src->Sem().Get(member); | ||||
|                 auto* sem = src->Sem().Get(member); | ||||
|                 info.type = sem->Type(); | ||||
| 
 | ||||
|                 TINT_ASSERT(Transform, sem->Location().has_value()); | ||||
| @ -770,7 +798,7 @@ struct State { | ||||
|                 } | ||||
|                 members_to_clone.Push(member); | ||||
|             } else { | ||||
|                 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; | ||||
|                 TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| @ -781,8 +809,8 @@ struct State { | ||||
|         } | ||||
| 
 | ||||
|         // Create a function-scope variable to replace the parameter.
 | ||||
|         auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type)); | ||||
|         ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); | ||||
|         auto* func_var = b.Var(param_sym, ctx.Clone(param->type)); | ||||
|         ctx.InsertFront(func->body->statements, b.Decl(func_var)); | ||||
| 
 | ||||
|         if (!members_to_clone.IsEmpty()) { | ||||
|             // Create a new struct without the location attributes.
 | ||||
| @ -791,20 +819,20 @@ struct State { | ||||
|                 auto member_sym = ctx.Clone(member->symbol); | ||||
|                 auto* member_type = ctx.Clone(member->type); | ||||
|                 auto member_attrs = ctx.Clone(member->attributes); | ||||
|                 new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs))); | ||||
|                 new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs))); | ||||
|             } | ||||
|             auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members); | ||||
|             auto* new_struct = b.Structure(b.Sym(), new_members); | ||||
| 
 | ||||
|             // Create a new function parameter with this struct.
 | ||||
|             auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct)); | ||||
|             auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct)); | ||||
|             new_function_parameters.Push(new_param); | ||||
| 
 | ||||
|             // Copy values from the new parameter to the function-scope variable.
 | ||||
|             for (auto* member : members_to_clone) { | ||||
|                 auto member_name = ctx.Clone(member->symbol); | ||||
|                 ctx.InsertFront(func->body->statements, | ||||
|                                 ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name), | ||||
|                                                 ctx.dst->MemberAccessor(new_param, member_name))); | ||||
|                                 b.Assign(b.MemberAccessor(func_var, member_name), | ||||
|                                          b.MemberAccessor(new_param, member_name))); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @ -818,7 +846,7 @@ struct State { | ||||
| 
 | ||||
|         // Process entry point parameters.
 | ||||
|         for (auto* param : func->params) { | ||||
|             auto* sem = ctx.src->Sem().Get(param); | ||||
|             auto* sem = src->Sem().Get(param); | ||||
|             if (auto* str = sem->Type()->As<sem::Struct>()) { | ||||
|                 ProcessStructParameter(func, param, str->Declaration()); | ||||
|             } else { | ||||
| @ -830,11 +858,11 @@ struct State { | ||||
|         if (!vertex_index_expr) { | ||||
|             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | ||||
|                 if (layout.step_mode == VertexStepMode::kVertex) { | ||||
|                     auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); | ||||
|                     new_function_parameters.Push(ctx.dst->Param( | ||||
|                         name, ctx.dst->ty.u32(), | ||||
|                         utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)})); | ||||
|                     vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; | ||||
|                     auto name = b.Symbols().New("tint_pulling_vertex_index"); | ||||
|                     new_function_parameters.Push( | ||||
|                         b.Param(name, b.ty.u32(), | ||||
|                                 utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)})); | ||||
|                     vertex_index_expr = [this, name]() { return b.Expr(name); }; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
| @ -842,11 +870,11 @@ struct State { | ||||
|         if (!instance_index_expr) { | ||||
|             for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { | ||||
|                 if (layout.step_mode == VertexStepMode::kInstance) { | ||||
|                     auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); | ||||
|                     new_function_parameters.Push(ctx.dst->Param( | ||||
|                         name, ctx.dst->ty.u32(), | ||||
|                         utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)})); | ||||
|                     instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; | ||||
|                     auto name = b.Symbols().New("tint_pulling_instance_index"); | ||||
|                     new_function_parameters.Push( | ||||
|                         b.Param(name, b.ty.u32(), | ||||
|                                 utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)})); | ||||
|                     instance_index_expr = [this, name]() { return b.Expr(name); }; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
| @ -864,53 +892,24 @@ struct State { | ||||
|         auto attrs = ctx.Clone(func->attributes); | ||||
|         auto ret_attrs = ctx.Clone(func->return_type_attributes); | ||||
|         auto* new_func = | ||||
|             ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters, | ||||
|                                            ret_type, body, std::move(attrs), std::move(ret_attrs)); | ||||
|             b.create<ast::Function>(func->source, func_sym, new_function_parameters, ret_type, body, | ||||
|                                     std::move(attrs), std::move(ret_attrs)); | ||||
|         ctx.Replace(func, new_func); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| VertexPulling::VertexPulling() = default; | ||||
| VertexPulling::~VertexPulling() = default; | ||||
| 
 | ||||
| void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { | ||||
| Transform::ApplyResult VertexPulling::Apply(const Program* src, | ||||
|                                             const DataMap& inputs, | ||||
|                                             DataMap&) const { | ||||
|     auto cfg = cfg_; | ||||
|     if (auto* cfg_data = inputs.Get<Config>()) { | ||||
|         cfg = *cfg_data; | ||||
|     } | ||||
| 
 | ||||
|     // Find entry point
 | ||||
|     const ast::Function* func = nullptr; | ||||
|     for (auto* fn : ctx.src->AST().Functions()) { | ||||
|         if (fn->PipelineStage() == ast::PipelineStage::kVertex) { | ||||
|             if (func != nullptr) { | ||||
|                 ctx.dst->Diagnostics().add_error( | ||||
|                     diag::System::Transform, | ||||
|                     "VertexPulling found more than one vertex entry point"); | ||||
|                 return; | ||||
|             } | ||||
|             func = fn; | ||||
|         } | ||||
|     } | ||||
|     if (func == nullptr) { | ||||
|         ctx.dst->Diagnostics().add_error(diag::System::Transform, | ||||
|                                          "Vertex stage entry point not found"); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     // TODO(idanr): Need to check shader locations in descriptor cover all
 | ||||
|     // attributes
 | ||||
| 
 | ||||
|     // TODO(idanr): Make sure we covered all error cases, to guarantee the
 | ||||
|     // following stages will pass
 | ||||
| 
 | ||||
|     State state{ctx, cfg}; | ||||
|     state.AddVertexStorageBuffers(); | ||||
|     state.Process(func); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return State{src, cfg}.Run(); | ||||
| } | ||||
| 
 | ||||
| VertexPulling::Config::Config() = default; | ||||
|  | ||||
| @ -171,16 +171,14 @@ class VertexPulling final : public Castable<VertexPulling, Transform> { | ||||
|     /// Destructor
 | ||||
|     ~VertexPulling() override; | ||||
| 
 | ||||
|   protected: | ||||
|     /// Runs the transform using the CloneContext built for transforming a
 | ||||
|     /// program. Run() is responsible for calling Clone() on the CloneContext.
 | ||||
|     /// @param ctx the CloneContext primed with the input program and
 | ||||
|     /// ProgramBuilder
 | ||||
|     /// @param inputs optional extra transform-specific input data
 | ||||
|     /// @param outputs optional extra transform-specific output data
 | ||||
|     void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; | ||||
|     /// @copydoc Transform::Apply
 | ||||
|     ApplyResult Apply(const Program* program, | ||||
|                       const DataMap& inputs, | ||||
|                       DataMap& outputs) const override; | ||||
| 
 | ||||
|   private: | ||||
|     struct State; | ||||
| 
 | ||||
|     Config cfg_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -14,18 +14,17 @@ | ||||
| 
 | ||||
| #include "src/tint/transform/while_to_loop.h" | ||||
| 
 | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "src/tint/ast/break_statement.h" | ||||
| #include "src/tint/program_builder.h" | ||||
| 
 | ||||
| TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop); | ||||
| 
 | ||||
| namespace tint::transform { | ||||
| namespace { | ||||
| 
 | ||||
| WhileToLoop::WhileToLoop() = default; | ||||
| 
 | ||||
| WhileToLoop::~WhileToLoop() = default; | ||||
| 
 | ||||
| bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { | ||||
| bool ShouldRun(const Program* program) { | ||||
|     for (auto* node : program->ASTNodes().Objects()) { | ||||
|         if (node->Is<ast::WhileStatement>()) { | ||||
|             return true; | ||||
| @ -34,20 +33,32 @@ bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| }  // namespace
 | ||||
| 
 | ||||
| WhileToLoop::WhileToLoop() = default; | ||||
| 
 | ||||
| WhileToLoop::~WhileToLoop() = default; | ||||
| 
 | ||||
| Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { | ||||
|     if (!ShouldRun(src)) { | ||||
|         return SkipTransform; | ||||
|     } | ||||
| 
 | ||||
|     ProgramBuilder b; | ||||
|     CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; | ||||
| 
 | ||||
|     ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* { | ||||
|         utils::Vector<const ast::Statement*, 16> stmts; | ||||
|         auto* cond = w->condition; | ||||
| 
 | ||||
|         // !condition
 | ||||
|         auto* not_cond = | ||||
|             ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond)); | ||||
|         auto* not_cond = b.Not(ctx.Clone(cond)); | ||||
| 
 | ||||
|         // { break; }
 | ||||
|         auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>()); | ||||
|         auto* break_body = b.Block(b.Break()); | ||||
| 
 | ||||
|         // if (!condition) { break; }
 | ||||
|         stmts.Push(ctx.dst->If(not_cond, break_body)); | ||||
|         stmts.Push(b.If(not_cond, break_body)); | ||||
| 
 | ||||
|         for (auto* stmt : w->body->statements) { | ||||
|             stmts.Push(ctx.Clone(stmt)); | ||||
| @ -55,13 +66,14 @@ void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { | ||||
| 
 | ||||
|         const ast::BlockStatement* continuing = nullptr; | ||||
| 
 | ||||
|         auto* body = ctx.dst->Block(stmts); | ||||
|         auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing); | ||||
|         auto* body = b.Block(stmts); | ||||
|         auto* loop = b.Loop(body, continuing); | ||||
| 
 | ||||
|         return loop; | ||||
|     }); | ||||
| 
 | ||||
|     ctx.Clone(); | ||||
|     return Program(std::move(b)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tint::transform
 | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user