diff --git a/src/tint/fuzzers/shuffle_transform.cc b/src/tint/fuzzers/shuffle_transform.cc index 5f5f6e61ac..6ae405a60a 100644 --- a/src/tint/fuzzers/shuffle_transform.cc +++ b/src/tint/fuzzers/shuffle_transform.cc @@ -15,6 +15,7 @@ #include "src/tint/fuzzers/shuffle_transform.h" #include +#include #include "src/tint/program_builder.h" @@ -22,15 +23,21 @@ namespace tint::fuzzers { ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {} -void ShuffleTransform::Run(CloneContext& ctx, - const tint::transform::DataMap&, - tint::transform::DataMap&) const { - auto decls = ctx.src->AST().GlobalDeclarations(); +transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src, + const transform::DataMap&, + transform::DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + auto decls = src->AST().GlobalDeclarations(); auto rng = std::mt19937_64{seed_}; std::shuffle(std::begin(decls), std::end(decls), rng); for (auto* decl : decls) { - ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl)); + b.AST().AddGlobalDeclaration(ctx.Clone(decl)); } + + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::fuzzers diff --git a/src/tint/fuzzers/shuffle_transform.h b/src/tint/fuzzers/shuffle_transform.h index 0a64fe394c..ee54f973af 100644 --- a/src/tint/fuzzers/shuffle_transform.h +++ b/src/tint/fuzzers/shuffle_transform.h @@ -20,16 +20,16 @@ namespace tint::fuzzers { /// ShuffleTransform reorders the module scope declarations into a random order -class ShuffleTransform : public tint::transform::Transform { +class ShuffleTransform : public transform::Transform { public: /// Constructor /// @param seed the random seed to use for the shuffling explicit ShuffleTransform(size_t seed); - protected: - void Run(CloneContext& ctx, - const tint::transform::DataMap&, - tint::transform::DataMap&) const override; + /// @copydoc transform::Transform::Apply + ApplyResult Apply(const Program* program, + const transform::DataMap& inputs, + transform::DataMap& outputs) const override; private: size_t seed_; diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index 8e271d7aea..634f5da7ef 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -23,6 +23,7 @@ #include "src/tint/ast/access.h" #include "src/tint/ast/address_space.h" +#include "src/tint/ast/parameter.h" #include "src/tint/sem/binding_point.h" #include "src/tint/sem/expression.h" #include "src/tint/sem/parameter_usage.h" @@ -212,6 +213,11 @@ class Parameter final : public Castable { /// Destructor ~Parameter() override; + /// @returns the AST declaration node + const ast::Parameter* Declaration() const { + return static_cast(Variable::Declaration()); + } + /// @return the index of the parmeter in the function uint32_t Index() const { return index_; } diff --git a/src/tint/transform/add_block_attribute.cc b/src/tint/transform/add_block_attribute.cc index 77d8719ff0..513925faa1 100644 --- a/src/tint/transform/add_block_attribute.cc +++ b/src/tint/transform/add_block_attribute.cc @@ -31,21 +31,29 @@ AddBlockAttribute::AddBlockAttribute() = default; AddBlockAttribute::~AddBlockAttribute() = default; -void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - auto& sem = ctx.src->Sem(); +Transform::ApplyResult AddBlockAttribute::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + auto& sem = src->Sem(); // A map from a type in the source program to a block-decorated wrapper that contains it in the // destination program. utils::Hashmap wrapper_structs; // Process global 'var' declarations that are buffers. - for (auto* global : ctx.src->AST().GlobalVariables()) { + bool made_changes = false; + for (auto* global : src->AST().GlobalVariables()) { auto* var = sem.Get(global); if (!ast::IsHostShareable(var->AddressSpace())) { // Not declared in a host-sharable address space continue; } + made_changes = true; + auto* ty = var->Type()->UnwrapRef(); auto* str = ty->As(); @@ -61,33 +69,36 @@ void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const { const char* kMemberName = "inner"; auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] { - auto* block = ctx.dst->ASTNodes().Create(ctx.dst->ID(), - ctx.dst->AllocateNodeID()); - auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block"; - auto* ret = ctx.dst->create( - ctx.dst->Symbols().New(wrapper_name), - utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))}, + auto* block = b.ASTNodes().Create(b.ID(), b.AllocateNodeID()); + auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block"; + auto* ret = b.create( + b.Symbols().New(wrapper_name), + utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))}, utils::Vector{block}); - ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret); + ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret); return ret; }); - ctx.Replace(global->type, ctx.dst->ty.Of(wrapper)); + ctx.Replace(global->type, b.ty.Of(wrapper)); // Insert a member accessor to get the original type from the wrapper at // any usage of the original variable. for (auto* user : var->Users()) { ctx.Replace(user->Declaration(), - ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName)); + b.MemberAccessor(ctx.Clone(global->symbol), kMemberName)); } } else { // Add a block attribute to this struct directly. - auto* block = ctx.dst->ASTNodes().Create(ctx.dst->ID(), - ctx.dst->AllocateNodeID()); + auto* block = b.ASTNodes().Create(b.ID(), b.AllocateNodeID()); ctx.InsertFront(str->Declaration()->attributes, block); } } + if (!made_changes) { + return SkipTransform; + } + ctx.Clone(); + return Program(std::move(b)); } AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid) diff --git a/src/tint/transform/add_block_attribute.h b/src/tint/transform/add_block_attribute.h index 2bfd63e750..d5e8c4e84c 100644 --- a/src/tint/transform/add_block_attribute.h +++ b/src/tint/transform/add_block_attribute.h @@ -53,14 +53,10 @@ class AddBlockAttribute final : public Castable { /// Destructor ~AddBlockAttribute() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/add_empty_entry_point.cc b/src/tint/transform/add_empty_entry_point.cc index 5ef4fe85eb..f71394dbf4 100644 --- a/src/tint/transform/add_empty_entry_point.cc +++ b/src/tint/transform/add_empty_entry_point.cc @@ -23,12 +23,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint); using namespace tint::number_suffixes; // NOLINT namespace tint::transform { +namespace { -AddEmptyEntryPoint::AddEmptyEntryPoint() = default; - -AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; - -bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* func : program->AST().Functions()) { if (func->IsEntryPoint()) { return false; @@ -37,13 +34,30 @@ bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const return true; } -void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {}, - utils::Vector{ - ctx.dst->Stage(ast::PipelineStage::kCompute), - ctx.dst->WorkgroupSize(1_i), - }); +} // namespace + +AddEmptyEntryPoint::AddEmptyEntryPoint() = default; + +AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; + +Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {}, + utils::Vector{ + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1_i), + }); + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/add_empty_entry_point.h b/src/tint/transform/add_empty_entry_point.h index 553035504b..828f3b5222 100644 --- a/src/tint/transform/add_empty_entry_point.h +++ b/src/tint/transform/add_empty_entry_point.h @@ -27,19 +27,10 @@ class AddEmptyEntryPoint final : public Castable /// Destructor ~AddEmptyEntryPoint() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/array_length_from_uniform.cc b/src/tint/transform/array_length_from_uniform.cc index 3938b0ca4d..70097f237e 100644 --- a/src/tint/transform/array_length_from_uniform.cc +++ b/src/tint/transform/array_length_from_uniform.cc @@ -31,13 +31,153 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result); namespace tint::transform { +namespace { + +bool ShouldRun(const Program* program) { + for (auto* fn : program->AST().Functions()) { + if (auto* sem_fn = program->Sem().Get(fn)) { + for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { + if (builtin->Type() == sem::BuiltinType::kArrayLength) { + return true; + } + } + } + } + return false; +} + +} // namespace + ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; -/// The PIMPL state for this transform +/// PIMPL state for the transform struct ArrayLengthFromUniform::State { + /// Constructor + /// @param program the source program + /// @param in the input transform data + /// @param out the output transform data + explicit State(const Program* program, const DataMap& in, DataMap& out) + : src(program), inputs(in), outputs(out) {} + + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { + auto* cfg = inputs.Get(); + if (cfg == nullptr) { + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + + std::string(TypeInfo::Of().name)); + return Program(std::move(b)); + } + + if (!ShouldRun(ctx.src)) { + return SkipTransform; + } + + const char* kBufferSizeMemberName = "buffer_size"; + + // Determine the size of the buffer size array. + uint32_t max_buffer_size_index = 0; + + IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*, + const sem::GlobalVariable* var) { + auto binding = var->BindingPoint(); + auto idx_itr = cfg->bindpoint_to_size_index.find(binding); + if (idx_itr == cfg->bindpoint_to_size_index.end()) { + return; + } + if (idx_itr->second > max_buffer_size_index) { + max_buffer_size_index = idx_itr->second; + } + }); + + // Get (or create, on first call) the uniform buffer that will receive the + // size of each storage buffer in the module. + const ast::Variable* buffer_size_ubo = nullptr; + auto get_ubo = [&]() { + if (!buffer_size_ubo) { + // Emit an array, N>, where N is 1/4 number of elements. + // We do this because UBOs require an element stride that is 16-byte + // aligned. + auto* buffer_size_struct = b.Structure( + b.Sym(), utils::Vector{ + b.Member(kBufferSizeMemberName, + b.ty.array(b.ty.vec4(b.ty.u32()), + u32((max_buffer_size_index / 4) + 1))), + }); + buffer_size_ubo = + b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform, + b.Group(AInt(cfg->ubo_binding.group)), + b.Binding(AInt(cfg->ubo_binding.binding))); + } + return buffer_size_ubo; + }; + + std::unordered_set used_size_indices; + + IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, + const sem::VariableUser* storage_buffer_sem, + const sem::GlobalVariable* var) { + auto binding = var->BindingPoint(); + auto idx_itr = cfg->bindpoint_to_size_index.find(binding); + if (idx_itr == cfg->bindpoint_to_size_index.end()) { + return; + } + + uint32_t size_index = idx_itr->second; + used_size_indices.insert(size_index); + + // Load the total storage buffer size from the UBO. + uint32_t array_index = size_index / 4; + auto* vec_expr = b.IndexAccessor( + b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); + uint32_t vec_index = size_index % 4; + auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index)); + + // Calculate actual array length + // total_storage_buffer_size - array_offset + // array_length = ---------------------------------------- + // array_stride + const ast::Expression* total_size = total_storage_buffer_size; + auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); + const sem::Array* array_type = nullptr; + if (auto* str = storage_buffer_type->As()) { + // The variable is a struct, so subtract the byte offset of the array + // member. + auto* array_member_sem = str->Members().back(); + array_type = array_member_sem->Type()->As(); + total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); + } else if (auto* arr = storage_buffer_type->As()) { + array_type = arr; + } else { + TINT_ICE(Transform, b.Diagnostics()) + << "expected form of arrayLength argument to be &array_var or " + "&struct_var.array_member"; + return; + } + auto* array_length = b.Div(total_size, u32(array_type->Stride())); + + ctx.Replace(call_expr, array_length); + }); + + outputs.Add(used_size_indices); + + ctx.Clone(); + return Program(std::move(b)); + } + + private: + /// The source program + const Program* const src; + /// The transform inputs + const DataMap& inputs; + /// The transform outputs + DataMap& outputs; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; /// Iterate over all arrayLength() builtins that operate on /// storage buffer variables. @@ -48,10 +188,10 @@ struct ArrayLengthFromUniform::State { /// sem::GlobalVariable for the storage buffer. template void IterateArrayLengthOnStorageVar(F&& functor) { - auto& sem = ctx.src->Sem(); + auto& sem = src->Sem(); // Find all calls to the arrayLength() builtin. - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { auto* call_expr = node->As(); if (!call_expr) { continue; @@ -79,7 +219,7 @@ struct ArrayLengthFromUniform::State { // arrayLength(&array_var) auto* param = call_expr->args[0]->As(); if (!param || param->op != ast::UnaryOp::kAddressOf) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "expected form of arrayLength argument to be &array_var or " "&struct_var.array_member"; break; @@ -90,7 +230,7 @@ struct ArrayLengthFromUniform::State { } auto* storage_buffer_sem = sem.Get(storage_buffer_expr); if (!storage_buffer_sem) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "expected form of arrayLength argument to be &array_var or " "&struct_var.array_member"; break; @@ -99,8 +239,7 @@ struct ArrayLengthFromUniform::State { // Get the index to use for the buffer size array. auto* var = tint::As(storage_buffer_sem->Variable()); if (!var) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) - << "storage buffer is not a global variable"; + TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable"; break; } functor(call_expr, storage_buffer_sem, var); @@ -108,117 +247,10 @@ struct ArrayLengthFromUniform::State { } }; -bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const { - for (auto* fn : program->AST().Functions()) { - if (auto* sem_fn = program->Sem().Get(fn)) { - for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { - if (builtin->Type() == sem::BuiltinType::kArrayLength) { - return true; - } - } - } - } - return false; -} - -void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { - auto* cfg = inputs.Get(); - if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - return; - } - - const char* kBufferSizeMemberName = "buffer_size"; - - // Determine the size of the buffer size array. - uint32_t max_buffer_size_index = 0; - - State{ctx}.IterateArrayLengthOnStorageVar( - [&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) { - auto binding = var->BindingPoint(); - auto idx_itr = cfg->bindpoint_to_size_index.find(binding); - if (idx_itr == cfg->bindpoint_to_size_index.end()) { - return; - } - if (idx_itr->second > max_buffer_size_index) { - max_buffer_size_index = idx_itr->second; - } - }); - - // Get (or create, on first call) the uniform buffer that will receive the - // size of each storage buffer in the module. - const ast::Variable* buffer_size_ubo = nullptr; - auto get_ubo = [&]() { - if (!buffer_size_ubo) { - // Emit an array, N>, where N is 1/4 number of elements. - // We do this because UBOs require an element stride that is 16-byte - // aligned. - auto* buffer_size_struct = ctx.dst->Structure( - ctx.dst->Sym(), - utils::Vector{ - ctx.dst->Member(kBufferSizeMemberName, - ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), - u32((max_buffer_size_index / 4) + 1))), - }); - buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), - ast::AddressSpace::kUniform, - ctx.dst->Group(AInt(cfg->ubo_binding.group)), - ctx.dst->Binding(AInt(cfg->ubo_binding.binding))); - } - return buffer_size_ubo; - }; - - std::unordered_set used_size_indices; - - State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, - const sem::VariableUser* storage_buffer_sem, - const sem::GlobalVariable* var) { - auto binding = var->BindingPoint(); - auto idx_itr = cfg->bindpoint_to_size_index.find(binding); - if (idx_itr == cfg->bindpoint_to_size_index.end()) { - return; - } - - uint32_t size_index = idx_itr->second; - used_size_indices.insert(size_index); - - // Load the total storage buffer size from the UBO. - uint32_t array_index = size_index / 4; - auto* vec_expr = ctx.dst->IndexAccessor( - ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index)); - uint32_t vec_index = size_index % 4; - auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index)); - - // Calculate actual array length - // total_storage_buffer_size - array_offset - // array_length = ---------------------------------------- - // array_stride - const ast::Expression* total_size = total_storage_buffer_size; - auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); - const sem::Array* array_type = nullptr; - if (auto* str = storage_buffer_type->As()) { - // The variable is a struct, so subtract the byte offset of the array - // member. - auto* array_member_sem = str->Members().back(); - array_type = array_member_sem->Type()->As(); - total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset())); - } else if (auto* arr = storage_buffer_type->As()) { - array_type = arr; - } else { - TINT_ICE(Transform, ctx.dst->Diagnostics()) - << "expected form of arrayLength argument to be &array_var or " - "&struct_var.array_member"; - return; - } - auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride())); - - ctx.Replace(call_expr, array_length); - }); - - ctx.Clone(); - - outputs.Add(used_size_indices); +Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src, + const DataMap& inputs, + DataMap& outputs) const { + return State{src, inputs, outputs}.Run(); } ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} diff --git a/src/tint/transform/array_length_from_uniform.h b/src/tint/transform/array_length_from_uniform.h index 8bd6af5f7f..507ea37294 100644 --- a/src/tint/transform/array_length_from_uniform.h +++ b/src/tint/transform/array_length_from_uniform.h @@ -100,22 +100,12 @@ class ArrayLengthFromUniform final : public Castable used_size_indices; }; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: - /// The PIMPL state for this transform struct State; }; diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc index 1058bf1b17..b5d9e77a45 100644 --- a/src/tint/transform/array_length_from_uniform_test.cc +++ b/src/tint/transform/array_length_from_uniform_test.cc @@ -28,7 +28,13 @@ using ArrayLengthFromUniformTest = TransformTest; TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) { auto* src = R"()"; - EXPECT_FALSE(ShouldRun(src)); + ArrayLengthFromUniform::Config cfg({0, 30u}); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); + + DataMap data; + data.Add(std::move(cfg)); + + EXPECT_FALSE(ShouldRun(src, data)); } TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) { @@ -45,7 +51,13 @@ fn main() { } )"; - EXPECT_FALSE(ShouldRun(src)); + ArrayLengthFromUniform::Config cfg({0, 30u}); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); + + DataMap data; + data.Add(std::move(cfg)); + + EXPECT_FALSE(ShouldRun(src, data)); } TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) { @@ -63,7 +75,13 @@ fn main() { } )"; - EXPECT_TRUE(ShouldRun(src)); + ArrayLengthFromUniform::Config cfg({0, 30u}); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0); + + DataMap data; + data.Add(std::move(cfg)); + + EXPECT_TRUE(ShouldRun(src, data)); } TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc index 798b22803c..0781355f1e 100644 --- a/src/tint/transform/binding_remapper.cc +++ b/src/tint/transform/binding_remapper.cc @@ -40,19 +40,21 @@ BindingRemapper::Remappings::~Remappings() = default; BindingRemapper::BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default; -bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const { - if (auto* remappings = inputs.Get()) { - return !remappings->binding_points.empty() || !remappings->access_controls.empty(); - } - return false; -} +Transform::ApplyResult BindingRemapper::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; -void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { auto* remappings = inputs.Get(); if (!remappings) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - return; + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return Program(std::move(b)); + } + + if (remappings->binding_points.empty() && remappings->access_controls.empty()) { + return SkipTransform; } // A set of post-remapped binding points that need to be decorated with a @@ -62,11 +64,11 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co if (remappings->allow_collisions) { // Scan for binding point collisions generated by this transform. // Populate all collisions in the `add_collision_attr` set. - for (auto* func_ast : ctx.src->AST().Functions()) { + for (auto* func_ast : src->AST().Functions()) { if (!func_ast->IsEntryPoint()) { continue; } - auto* func = ctx.src->Sem().Get(func_ast); + auto* func = src->Sem().Get(func_ast); std::unordered_map binding_point_counts; for (auto* global : func->TransitivelyReferencedGlobals()) { if (global->Declaration()->HasBindingPoint()) { @@ -90,9 +92,9 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co } } - for (auto* var : ctx.src->AST().Globals()) { + for (auto* var : src->AST().Globals()) { if (var->HasBindingPoint()) { - auto* global_sem = ctx.src->Sem().Get(var); + auto* global_sem = src->Sem().Get(var); // The original binding point BindingPoint from = global_sem->BindingPoint(); @@ -106,8 +108,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co auto bp_it = remappings->binding_points.find(from); if (bp_it != remappings->binding_points.end()) { BindingPoint to = bp_it->second; - auto* new_group = ctx.dst->Group(AInt(to.group)); - auto* new_binding = ctx.dst->Binding(AInt(to.binding)); + auto* new_group = b.Group(AInt(to.group)); + auto* new_binding = b.Binding(AInt(to.binding)); auto* old_group = ast::GetAttribute(var->attributes); auto* old_binding = ast::GetAttribute(var->attributes); @@ -122,37 +124,37 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co if (ac_it != remappings->access_controls.end()) { ast::Access ac = ac_it->second; if (ac == ast::Access::kUndefined) { - ctx.dst->Diagnostics().add_error( + b.Diagnostics().add_error( diag::System::Transform, "invalid access mode (" + std::to_string(static_cast(ac)) + ")"); - return; + return Program(std::move(b)); } - auto* sem = ctx.src->Sem().Get(var); + auto* sem = src->Sem().Get(var); if (sem->AddressSpace() != ast::AddressSpace::kStorage) { - ctx.dst->Diagnostics().add_error( + b.Diagnostics().add_error( diag::System::Transform, "cannot apply access control to variable with address space " + std::string(utils::ToString(sem->AddressSpace()))); - return; + return Program(std::move(b)); } auto* ty = sem->Type()->UnwrapRef(); const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); - auto* new_var = - ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, - var->declared_address_space, ac, ctx.Clone(var->initializer), - ctx.Clone(var->attributes)); + auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, + var->declared_address_space, ac, ctx.Clone(var->initializer), + ctx.Clone(var->attributes)); ctx.Replace(var, new_var); } // Add `DisableValidationAttribute`s if required if (add_collision_attr.count(bp)) { - auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision); + auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision); ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute); } } } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/binding_remapper.h b/src/tint/transform/binding_remapper.h index 77fc5bce87..b0efe0dbe4 100644 --- a/src/tint/transform/binding_remapper.h +++ b/src/tint/transform/binding_remapper.h @@ -67,19 +67,10 @@ class BindingRemapper final : public Castable { BindingRemapper(); ~BindingRemapper() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/binding_remapper_test.cc b/src/tint/transform/binding_remapper_test.cc index 564a3a5b24..5bafb7e61f 100644 --- a/src/tint/transform/binding_remapper_test.cc +++ b/src/tint/transform/binding_remapper_test.cc @@ -23,12 +23,6 @@ namespace { using BindingRemapperTest = TransformTest; -TEST_F(BindingRemapperTest, ShouldRunNoRemappings) { - auto* src = R"()"; - - EXPECT_FALSE(ShouldRun(src)); -} - TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) { auto* src = R"()"; @@ -350,7 +344,7 @@ fn f() { } )"; - auto* expect = src; + auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)"; auto got = Run(src); diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc index 1e28e7e515..db200e4ce9 100644 --- a/src/tint/transform/builtin_polyfill.cc +++ b/src/tint/transform/builtin_polyfill.cc @@ -29,7 +29,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::BuiltinPolyfill::Config); namespace tint::transform { -/// The PIMPL state for the BuiltinPolyfill transform +/// PIMPL state for the transform struct BuiltinPolyfill::State { /// Constructor /// @param c the CloneContext @@ -604,193 +604,100 @@ BuiltinPolyfill::BuiltinPolyfill() = default; BuiltinPolyfill::~BuiltinPolyfill() = default; -bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const { - if (auto* cfg = data.Get()) { - auto builtins = cfg->builtins; - auto& sem = program->Sem(); - for (auto* node : program->ASTNodes().Objects()) { - if (auto* call = sem.Get(node)) { - if (auto* builtin = call->Target()->As()) { - if (call->Stage() == sem::EvaluationStage::kConstant) { - continue; // Don't polyfill @const expressions - } - switch (builtin->Type()) { - case sem::BuiltinType::kAcosh: - if (builtins.acosh != Level::kNone) { - return true; - } - break; - case sem::BuiltinType::kAsinh: - if (builtins.asinh) { - return true; - } - break; - case sem::BuiltinType::kAtanh: - if (builtins.atanh != Level::kNone) { - return true; - } - break; - case sem::BuiltinType::kClamp: - if (builtins.clamp_int) { - auto& sig = builtin->Signature(); - return sig.parameters[0]->Type()->is_integer_scalar_or_vector(); - } - break; - case sem::BuiltinType::kCountLeadingZeros: - if (builtins.count_leading_zeros) { - return true; - } - break; - case sem::BuiltinType::kCountTrailingZeros: - if (builtins.count_trailing_zeros) { - return true; - } - break; - case sem::BuiltinType::kExtractBits: - if (builtins.extract_bits != Level::kNone) { - return true; - } - break; - case sem::BuiltinType::kFirstLeadingBit: - if (builtins.first_leading_bit) { - return true; - } - break; - case sem::BuiltinType::kFirstTrailingBit: - if (builtins.first_trailing_bit) { - return true; - } - break; - case sem::BuiltinType::kInsertBits: - if (builtins.insert_bits != Level::kNone) { - return true; - } - break; - case sem::BuiltinType::kSaturate: - if (builtins.saturate) { - return true; - } - break; - case sem::BuiltinType::kTextureSampleBaseClampToEdge: - if (builtins.texture_sample_base_clamp_to_edge_2d_f32) { - auto& sig = builtin->Signature(); - auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); - if (auto* stex = tex->Type()->As()) { - return stex->type()->Is(); - } - } - break; - case sem::BuiltinType::kQuantizeToF16: - if (builtins.quantize_to_vec_f16) { - if (builtin->ReturnType()->Is()) { - return true; - } - } - break; - default: - break; - } - } - } - } - } - return false; -} - -void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const { +Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, + const DataMap& data, + DataMap&) const { auto* cfg = data.Get(); if (!cfg) { - ctx.Clone(); - return; + return SkipTransform; } - std::unordered_map polyfills; + auto& builtins = cfg->builtins; - ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { - auto builtins = cfg->builtins; - State s{ctx, builtins}; - if (auto* call = s.sem.Get(expr)) { + utils::Hashmap polyfills; + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + State s{ctx, builtins}; + + bool made_changes = false; + for (auto* node : src->ASTNodes().Objects()) { + if (auto* call = src->Sem().Get(node)) { if (auto* builtin = call->Target()->As()) { if (call->Stage() == sem::EvaluationStage::kConstant) { - return nullptr; // Don't polyfill @const expressions + continue; // Don't polyfill @const expressions } Symbol polyfill; switch (builtin->Type()) { case sem::BuiltinType::kAcosh: if (builtins.acosh != Level::kNone) { - polyfill = utils::GetOrCreate( - polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.acosh(builtin->ReturnType()); }); } break; case sem::BuiltinType::kAsinh: if (builtins.asinh) { - polyfill = utils::GetOrCreate( - polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.asinh(builtin->ReturnType()); }); } break; case sem::BuiltinType::kAtanh: if (builtins.atanh != Level::kNone) { - polyfill = utils::GetOrCreate( - polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.atanh(builtin->ReturnType()); }); } break; case sem::BuiltinType::kClamp: if (builtins.clamp_int) { auto& sig = builtin->Signature(); if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.clampInteger(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.clampInteger(builtin->ReturnType()); }); } } break; case sem::BuiltinType::kCountLeadingZeros: if (builtins.count_leading_zeros) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { + polyfill = polyfills.GetOrCreate(builtin, [&] { return s.countLeadingZeros(builtin->ReturnType()); }); } break; case sem::BuiltinType::kCountTrailingZeros: if (builtins.count_trailing_zeros) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { + polyfill = polyfills.GetOrCreate(builtin, [&] { return s.countTrailingZeros(builtin->ReturnType()); }); } break; case sem::BuiltinType::kExtractBits: if (builtins.extract_bits != Level::kNone) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.extractBits(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.extractBits(builtin->ReturnType()); }); } break; case sem::BuiltinType::kFirstLeadingBit: if (builtins.first_leading_bit) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.firstLeadingBit(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); }); } break; case sem::BuiltinType::kFirstTrailingBit: if (builtins.first_trailing_bit) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.firstTrailingBit(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); }); } break; case sem::BuiltinType::kInsertBits: if (builtins.insert_bits != Level::kNone) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.insertBits(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.insertBits(builtin->ReturnType()); }); } break; case sem::BuiltinType::kSaturate: if (builtins.saturate) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { - return s.saturate(builtin->ReturnType()); - }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.saturate(builtin->ReturnType()); }); } break; case sem::BuiltinType::kTextureSampleBaseClampToEdge: @@ -799,7 +706,7 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); if (auto* stex = tex->Type()->As()) { if (stex->type()->Is()) { - polyfill = utils::GetOrCreate(polyfills, builtin, [&] { + polyfill = polyfills.GetOrCreate(builtin, [&] { return s.textureSampleBaseClampToEdge_2d_f32(); }); } @@ -809,8 +716,8 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons case sem::BuiltinType::kQuantizeToF16: if (builtins.quantize_to_vec_f16) { if (auto* vec = builtin->ReturnType()->As()) { - polyfill = utils::GetOrCreate(polyfills, builtin, - [&] { return s.quantizeToF16(vec); }); + polyfill = polyfills.GetOrCreate( + builtin, [&] { return s.quantizeToF16(vec); }); } } break; @@ -819,14 +726,20 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons break; } if (polyfill.IsValid()) { - return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); + auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args)); + ctx.Replace(call->Declaration(), replacement); + made_changes = true; } } } - return nullptr; - }); + } + + if (!made_changes) { + return SkipTransform; + } ctx.Clone(); + return Program(std::move(b)); } BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {} diff --git a/src/tint/transform/builtin_polyfill.h b/src/tint/transform/builtin_polyfill.h index 71069156a5..231d753571 100644 --- a/src/tint/transform/builtin_polyfill.h +++ b/src/tint/transform/builtin_polyfill.h @@ -87,21 +87,13 @@ class BuiltinPolyfill final : public Castable { const Builtins builtins; }; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; - protected: + private: struct State; - - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/builtin_polyfill_test.cc b/src/tint/transform/builtin_polyfill_test.cc index 65547abf4d..4bff0ff974 100644 --- a/src/tint/transform/builtin_polyfill_test.cc +++ b/src/tint/transform/builtin_polyfill_test.cc @@ -1561,7 +1561,8 @@ fn f() { TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) { auto* src = R"( fn f() { - let r : i32 = insertBits(1234, 5678, 5u, 6u); + let v = 1234i; + let r : i32 = insertBits(v, 5678, 5u, 6u); } )"; @@ -1975,10 +1976,6 @@ fn f() { )"; auto* expect = R"( -@group(0) @binding(0) var t : texture_2d; - -@group(0) @binding(1) var s : sampler; - fn tint_textureSampleBaseClampToEdge(t : texture_2d, s : sampler, coord : vec2) -> vec4 { let dims = vec2(textureDimensions(t, 0)); let half_texel = (vec2(0.5) / dims); @@ -1986,6 +1983,10 @@ fn tint_textureSampleBaseClampToEdge(t : texture_2d, s : sampler, coord : v return textureSampleLevel(t, s, clamped, 0); } +@group(0) @binding(0) var t : texture_2d; + +@group(0) @binding(1) var s : sampler; + fn f() { let r = tint_textureSampleBaseClampToEdge(t, s, vec2(0.5)); } diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc index 2ca5e5460d..9dcdd7b8eb 100644 --- a/src/tint/transform/calculate_array_length.cc +++ b/src/tint/transform/calculate_array_length.cc @@ -40,6 +40,19 @@ namespace tint::transform { namespace { +bool ShouldRun(const Program* program) { + for (auto* fn : program->AST().Functions()) { + if (auto* sem_fn = program->Sem().Get(fn)) { + for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { + if (builtin->Type() == sem::BuiltinType::kArrayLength) { + return true; + } + } + } + } + return false; +} + /// ArrayUsage describes a runtime array usage. /// It is used as a key by the array_length_by_usage map. struct ArrayUsage { @@ -73,21 +86,16 @@ const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSiz CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default; -bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const { - for (auto* fn : program->AST().Functions()) { - if (auto* sem_fn = program->Sem().Get(fn)) { - for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) { - if (builtin->Type() == sem::BuiltinType::kArrayLength) { - return true; - } - } - } +Transform::ApplyResult CalculateArrayLength::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; } - return false; -} -void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - auto& sem = ctx.src->Sem(); + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + auto& sem = src->Sem(); // get_buffer_size_intrinsic() emits the function decorated with // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to @@ -95,24 +103,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons std::unordered_map buffer_size_intrinsics; auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) { return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { - auto name = ctx.dst->Sym(); + auto name = b.Sym(); auto* type = CreateASTTypeFor(ctx, buffer_type); - auto* disable_validation = - ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter); - ctx.dst->AST().AddFunction(ctx.dst->create( + auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter); + b.AST().AddFunction(b.create( name, utils::Vector{ - ctx.dst->Param("buffer", - ctx.dst->ty.pointer(type, buffer_type->AddressSpace(), - buffer_type->Access()), - utils::Vector{disable_validation}), - ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(), - ast::AddressSpace::kFunction)), + b.Param("buffer", + b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()), + utils::Vector{disable_validation}), + b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)), }, - ctx.dst->ty.void_(), nullptr, + b.ty.void_(), nullptr, utils::Vector{ - ctx.dst->ASTNodes().Create(ctx.dst->ID(), - ctx.dst->AllocateNodeID()), + b.ASTNodes().Create(b.ID(), b.AllocateNodeID()), }, utils::Empty)); @@ -123,7 +127,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons std::unordered_map array_length_by_usage; // Find all the arrayLength() calls... - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { if (auto* call_expr = node->As()) { auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As(); if (auto* builtin = call->Target()->As()) { @@ -149,7 +153,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons auto* arg = call_expr->args[0]; auto* address_of = arg->As(); if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "arrayLength() expected address-of, got " << arg->TypeInfo().name; } auto* storage_buffer_expr = address_of->expr; @@ -158,7 +162,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons } auto* storage_buffer_sem = sem.Get(storage_buffer_expr); if (!storage_buffer_sem) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "expected form of arrayLength argument to be &array_var or " "&struct_var.array_member"; break; @@ -179,25 +183,24 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons // Construct the variable that'll hold the result of // RWByteAddressBuffer.GetDimensions() - auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var( - ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u))); + auto* buffer_size_result = + b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u))); // Call storage_buffer.GetDimensions(&buffer_size_result) - auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( + auto* call_get_dims = b.CallStmt(b.Call( // BufferSizeIntrinsic(X, ARGS...) is // translated to: // X.GetDimensions(ARGS..) by the writer - buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)), - ctx.dst->AddressOf( - ctx.dst->Expr(buffer_size_result->variable->symbol)))); + buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)), + b.AddressOf(b.Expr(buffer_size_result->variable->symbol)))); // Calculate actual array length // total_storage_buffer_size - array_offset // array_length = ---------------------------------------- // array_stride - auto name = ctx.dst->Sym(); + auto name = b.Sym(); const ast::Expression* total_size = - ctx.dst->Expr(buffer_size_result->variable); + b.Expr(buffer_size_result->variable); const sem::Array* array_type = Switch( storage_buffer_type->StoreType(), @@ -205,23 +208,21 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons // The variable is a struct, so subtract the byte offset of // the array member. auto* array_member_sem = str->Members().back(); - total_size = - ctx.dst->Sub(total_size, u32(array_member_sem->Offset())); + total_size = b.Sub(total_size, u32(array_member_sem->Offset())); return array_member_sem->Type()->As(); }, [&](const sem::Array* arr) { return arr; }); if (!array_type) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "expected form of arrayLength argument to be " "&array_var or &struct_var.array_member"; return name; } uint32_t array_stride = array_type->Size(); - auto* array_length_var = ctx.dst->Decl( - ctx.dst->Let(name, ctx.dst->ty.u32(), - ctx.dst->Div(total_size, u32(array_stride)))); + auto* array_length_var = b.Decl( + b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride)))); // Insert the array length calculations at the top of the block ctx.InsertBefore(block->statements, block->statements[0], @@ -234,13 +235,14 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons }); // Replace the call to arrayLength() with the array length variable - ctx.Replace(call_expr, ctx.dst->Expr(array_length)); + ctx.Replace(call_expr, b.Expr(array_length)); } } } } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/calculate_array_length.h b/src/tint/transform/calculate_array_length.h index 8db8dcceca..e5714a85cd 100644 --- a/src/tint/transform/calculate_array_length.h +++ b/src/tint/transform/calculate_array_length.h @@ -59,19 +59,10 @@ class CalculateArrayLength final : public Castable attrs) { } // namespace -/// State holds the current transform state for a single entry point. +/// PIMPL state for the transform struct CanonicalizeEntryPointIO::State { /// OutputValue represents a shader result that the wrapper function produces. struct OutputValue { @@ -770,17 +770,22 @@ struct CanonicalizeEntryPointIO::State { } }; -void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + auto* cfg = inputs.Get(); if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - return; + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return Program(std::move(b)); } // Remove entry point IO attributes from struct declarations. // New structures will be created for each entry point, as necessary. - for (auto* ty : ctx.src->AST().TypeDecls()) { + for (auto* ty : src->AST().TypeDecls()) { if (auto* struct_ty = ty->As()) { for (auto* member : struct_ty->members) { for (auto* attr : member->attributes) { @@ -792,7 +797,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat } } - for (auto* func_ast : ctx.src->AST().Functions()) { + for (auto* func_ast : src->AST().Functions()) { if (!func_ast->IsEntryPoint()) { continue; } @@ -802,6 +807,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat } ctx.Clone(); + return Program(std::move(b)); } CanonicalizeEntryPointIO::Config::Config(ShaderStyle style, diff --git a/src/tint/transform/canonicalize_entry_point_io.h b/src/tint/transform/canonicalize_entry_point_io.h index 95f8b197df..fbfed5ebd8 100644 --- a/src/tint/transform/canonicalize_entry_point_io.h +++ b/src/tint/transform/canonicalize_entry_point_io.h @@ -127,15 +127,12 @@ class CanonicalizeEntryPointIO final : public Castable +#include #include "src/tint/ast/attribute.h" #include "src/tint/ast/builtin_attribute.h" @@ -64,12 +64,7 @@ bool ReturnsFragDepthInStruct(const sem::Info& sem, const ast::Function* fn) { return false; } -} // anonymous namespace - -ClampFragDepth::ClampFragDepth() = default; -ClampFragDepth::~ClampFragDepth() = default; - -bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { auto& sem = program->Sem(); for (auto* fn : program->AST().Functions()) { @@ -82,22 +77,33 @@ bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const { return false; } -void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +} // anonymous namespace + +ClampFragDepth::ClampFragDepth() = default; +ClampFragDepth::~ClampFragDepth() = default; + +Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + // Abort on any use of push constants in the module. - for (auto* global : ctx.src->AST().GlobalVariables()) { + for (auto* global : src->AST().GlobalVariables()) { if (auto* var = global->As()) { if (var->declared_address_space == ast::AddressSpace::kPushConstant) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "ClampFragDepth doesn't know how to handle module that already use push " "constants."; - return; + return Program(std::move(b)); } } } - auto& b = *ctx.dst; - auto& sem = ctx.src->Sem(); - auto& sym = ctx.src->Symbols(); + if (!ShouldRun(src)) { + return SkipTransform; + } + + auto& sem = src->Sem(); + auto& sym = src->Symbols(); // At least one entry-point needs clamping. Add the following to the module: // @@ -197,6 +203,7 @@ void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const { }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/clamp_frag_depth.h b/src/tint/transform/clamp_frag_depth.h index 3b15f11c1f..1e9d0d6146 100644 --- a/src/tint/transform/clamp_frag_depth.h +++ b/src/tint/transform/clamp_frag_depth.h @@ -61,19 +61,10 @@ class ClampFragDepth final : public Castable { /// Destructor ~ClampFragDepth() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc index 97650ad0f2..e7286d412b 100644 --- a/src/tint/transform/combine_samplers.cc +++ b/src/tint/transform/combine_samplers.cc @@ -47,10 +47,14 @@ CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map, CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default; CombineSamplers::BindingInfo::~BindingInfo() = default; -/// The PIMPL state for the CombineSamplers transform +/// PIMPL state for the transform struct CombineSamplers::State { + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; /// The binding info const BindingInfo* binding_info; @@ -88,9 +92,9 @@ struct CombineSamplers::State { } /// Constructor - /// @param context the clone context + /// @param program the source program /// @param info the binding map information - State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {} + State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {} /// Creates a combined sampler global variables. /// (Note this is actually a Texture node at the AST level, but it will be @@ -145,8 +149,9 @@ struct CombineSamplers::State { } } - /// Performs the transformation - void Run() { + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { auto& sem = ctx.src->Sem(); // Remove all texture and sampler global variables. These will be replaced @@ -169,14 +174,14 @@ struct CombineSamplers::State { // Rewrite all function signatures to use combined samplers, and remove // separate textures & samplers. Create new combined globals where found. - ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* { - if (auto* func = sem.Get(src)) { - auto pairs = func->TextureSamplerPairs(); + ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* { + if (auto* fn = sem.Get(ast_fn)) { + auto pairs = fn->TextureSamplerPairs(); if (pairs.IsEmpty()) { return nullptr; } utils::Vector params; - for (auto pair : func->TextureSamplerPairs()) { + for (auto pair : fn->TextureSamplerPairs()) { const sem::Variable* texture_var = pair.first; const sem::Variable* sampler_var = pair.second; std::string name = @@ -197,23 +202,23 @@ struct CombineSamplers::State { auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var); auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); params.Push(var); - function_combined_texture_samplers_[func][pair] = var; + function_combined_texture_samplers_[fn][pair] = var; } } // Filter out separate textures and samplers from the original // function signature. - for (auto* var : src->params) { - if (!sem.Get(var->type)->IsAnyOf()) { - params.Push(ctx.Clone(var)); + for (auto* param : fn->Parameters()) { + if (!param->Type()->IsAnyOf()) { + params.Push(ctx.Clone(param->Declaration())); } } // Create a new function signature that differs only in the parameter // list. - auto symbol = ctx.Clone(src->symbol); - auto* return_type = ctx.Clone(src->return_type); - auto* body = ctx.Clone(src->body); - auto attributes = ctx.Clone(src->attributes); - auto return_type_attributes = ctx.Clone(src->return_type_attributes); + auto symbol = ctx.Clone(ast_fn->symbol); + auto* return_type = ctx.Clone(ast_fn->return_type); + auto* body = ctx.Clone(ast_fn->body); + auto attributes = ctx.Clone(ast_fn->attributes); + auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes); return ctx.dst->create(symbol, params, return_type, body, std::move(attributes), std::move(return_type_attributes)); @@ -327,6 +332,7 @@ struct CombineSamplers::State { }); ctx.Clone(); + return Program(std::move(b)); } }; @@ -334,15 +340,18 @@ CombineSamplers::CombineSamplers() = default; CombineSamplers::~CombineSamplers() = default; -void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult CombineSamplers::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { auto* binding_info = inputs.Get(); if (!binding_info) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - return; + ProgramBuilder b; + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return Program(std::move(b)); } - State(ctx, binding_info).Run(); + return State(src, binding_info).Run(); } } // namespace tint::transform diff --git a/src/tint/transform/combine_samplers.h b/src/tint/transform/combine_samplers.h index 8dfc098792..6834abe77d 100644 --- a/src/tint/transform/combine_samplers.h +++ b/src/tint/transform/combine_samplers.h @@ -88,17 +88,13 @@ class CombineSamplers final : public Castable { /// Destructor ~CombineSamplers() override; - protected: - /// The PIMPL state for this transform - struct State; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + private: + struct State; }; } // namespace tint::transform diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc index 68324af188..046583e2a3 100644 --- a/src/tint/transform/decompose_memory_access.cc +++ b/src/tint/transform/decompose_memory_access.cc @@ -47,6 +47,18 @@ namespace tint::transform { namespace { +bool ShouldRun(const Program* program) { + for (auto* decl : program->AST().GlobalDeclarations()) { + if (auto* var = program->Sem().Get(decl)) { + if (var->AddressSpace() == ast::AddressSpace::kStorage || + var->AddressSpace() == ast::AddressSpace::kUniform) { + return true; + } + } + } + return false; +} + /// Offset is a simple ast::Expression builder interface, used to build byte /// offsets for storage and uniform buffer accesses. struct Offset : Castable { @@ -291,7 +303,7 @@ struct Store { } // namespace -/// State holds the current transform state +/// PIMPL state for the transform struct DecomposeMemoryAccess::State { /// The clone context CloneContext& ctx; @@ -477,7 +489,7 @@ struct DecomposeMemoryAccess::State { // * Override-expression counts can only be applied to workgroup arrays, and // this method only handles storage and uniform. // * Runtime-sized arrays are not loadable. - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; arr_cnt = 1; } @@ -578,7 +590,7 @@ struct DecomposeMemoryAccess::State { // * Override-expression counts can only be applied to workgroup // arrays, and this method only handles storage and uniform. // * Runtime-sized arrays are not storable. - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; arr_cnt = 1; } @@ -808,21 +820,16 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const { DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; -bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const { - for (auto* decl : program->AST().GlobalDeclarations()) { - if (auto* var = program->Sem().Get(decl)) { - if (var->AddressSpace() == ast::AddressSpace::kStorage || - var->AddressSpace() == ast::AddressSpace::kUniform) { - return true; - } - } +Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; } - return false; -} - -void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - auto& sem = ctx.src->Sem(); + auto& sem = src->Sem(); + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; State state(ctx); // Scan the AST nodes for storage and uniform buffer accesses. Complex @@ -833,7 +840,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con // Inner-most expression nodes are guaranteed to be visited first because AST // nodes are fully immutable and require their children to be constructed // first so their pointer can be passed to the parent's initializer. - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { if (auto* ident = node->As()) { // X if (auto* var = sem.Get(ident)) { @@ -1001,6 +1008,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/decompose_memory_access.h b/src/tint/transform/decompose_memory_access.h index 2e92a3a41c..21c196b864 100644 --- a/src/tint/transform/decompose_memory_access.h +++ b/src/tint/transform/decompose_memory_access.h @@ -108,20 +108,12 @@ class DecomposeMemoryAccess final : public Castable; -} // namespace - -DecomposeStridedArray::DecomposeStridedArray() = default; - -DecomposeStridedArray::~DecomposeStridedArray() = default; - -bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* node : program->ASTNodes().Objects()) { if (auto* ast = node->As()) { if (ast::GetAttribute(ast->attributes)) { @@ -51,8 +45,22 @@ bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) co return false; } -void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - const auto& sem = ctx.src->Sem(); +} // namespace + +DecomposeStridedArray::DecomposeStridedArray() = default; + +DecomposeStridedArray::~DecomposeStridedArray() = default; + +Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + const auto& sem = src->Sem(); static constexpr const char* kMemberName = "el"; @@ -69,23 +77,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con if (auto* arr = sem.Get(ast)) { if (!arr->IsStrideImplicit()) { auto el_ty = utils::GetOrCreate(decomposed, arr, [&] { - auto name = ctx.dst->Symbols().New("strided_arr"); + auto name = b.Symbols().New("strided_arr"); auto* member_ty = ctx.Clone(ast->type); - auto* member = ctx.dst->Member(kMemberName, member_ty, - utils::Vector{ - ctx.dst->MemberSize(AInt(arr->Stride())), - }); - ctx.dst->Structure(name, utils::Vector{member}); + auto* member = b.Member(kMemberName, member_ty, + utils::Vector{ + b.MemberSize(AInt(arr->Stride())), + }); + b.Structure(name, utils::Vector{member}); return name; }); auto* count = ctx.Clone(ast->count); - return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count); + return b.ty.array(b.ty.type_name(el_ty), count); } if (ast::GetAttribute(ast->attributes)) { // Strip the @stride attribute auto* ty = ctx.Clone(ast->type); auto* count = ctx.Clone(ast->count); - return ctx.dst->ty.array(ty, count); + return b.ty.array(ty, count); } } return nullptr; @@ -96,11 +104,11 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con // to insert an additional member accessor for the single structure field. // Example: `arr[i]` -> `arr[i].el` ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* { - if (auto* ty = ctx.src->TypeOf(idx->object)) { + if (auto* ty = src->TypeOf(idx->object)) { if (auto* arr = ty->UnwrapRef()->As()) { if (!arr->IsStrideImplicit()) { auto* expr = ctx.CloneWithoutTransform(idx); - return ctx.dst->MemberAccessor(expr, kMemberName); + return b.MemberAccessor(expr, kMemberName); } } } @@ -136,21 +144,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con if (auto it = decomposed.find(arr); it != decomposed.end()) { args.Reserve(expr->args.Length()); for (auto* arg : expr->args) { - args.Push(ctx.dst->Call(it->second, ctx.Clone(arg))); + args.Push(b.Call(it->second, ctx.Clone(arg))); } } else { args = ctx.Clone(expr->args); } - return target.type ? ctx.dst->Construct(target.type, std::move(args)) - : ctx.dst->Call(target.name, std::move(args)); + return target.type ? b.Construct(target.type, std::move(args)) + : b.Call(target.name, std::move(args)); } } } } return nullptr; }); + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/decompose_strided_array.h b/src/tint/transform/decompose_strided_array.h index 5dbaaa5f5d..9555a9aa1d 100644 --- a/src/tint/transform/decompose_strided_array.h +++ b/src/tint/transform/decompose_strided_array.h @@ -35,19 +35,10 @@ class DecomposeStridedArray final : public Castable -void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { - for (auto* node : program->ASTNodes().Objects()) { +DecomposeStridedMatrix::DecomposeStridedMatrix() = default; + +DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; + +Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + // Scan the program for all storage and uniform structure matrix members with + // a custom stride attribute. Replace these matrices with an equivalent array, + // and populate the `decomposed` map with the members that have been replaced. + utils::Hashmap decomposed; + for (auto* node : src->ASTNodes().Objects()) { if (auto* str = node->As()) { - auto* str_ty = program->Sem().Get(str); + auto* str_ty = src->Sem().Get(str); if (!str_ty->UsedAs(ast::AddressSpace::kUniform) && !str_ty->UsedAs(ast::AddressSpace::kStorage)) { continue; @@ -89,46 +90,20 @@ void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { if (matrix->ColumnStride() == stride) { continue; } - if (callback(member, matrix, stride) == GatherResult::kStop) { - return; - } + // We've got ourselves a struct member of a matrix type with a custom + // stride. Replace this with an array of column vectors. + MatrixInfo info{stride, matrix}; + auto* replacement = + b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); + ctx.Replace(member->Declaration(), replacement); + decomposed.Add(member->Declaration(), info); } } } -} -} // namespace - -DecomposeStridedMatrix::DecomposeStridedMatrix() = default; - -DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; - -bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const { - bool should_run = false; - GatherCustomStrideMatrixMembers(program, - [&](const sem::StructMember*, const sem::Matrix*, uint32_t) { - should_run = true; - return GatherResult::kStop; - }); - return should_run; -} - -void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - // Scan the program for all storage and uniform structure matrix members with - // a custom stride attribute. Replace these matrices with an equivalent array, - // and populate the `decomposed` map with the members that have been replaced. - std::unordered_map decomposed; - GatherCustomStrideMatrixMembers( - ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) { - // We've got ourselves a struct member of a matrix type with a custom - // stride. Replace this with an array of column vectors. - MatrixInfo info{stride, matrix}; - auto* replacement = - ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); - ctx.Replace(member->Declaration(), replacement); - decomposed.emplace(member->Declaration(), info); - return GatherResult::kContinue; - }); + if (decomposed.IsEmpty()) { + return SkipTransform; + } // For all expressions where a single matrix column vector was indexed, we can // preserve these without calling conversion functions. @@ -136,12 +111,11 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co // ssbo.mat[2] -> ssbo.mat[2] ctx.ReplaceAll( [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { - if (auto* access = ctx.src->Sem().Get(expr->object)) { - auto it = decomposed.find(access->Member()->Declaration()); - if (it != decomposed.end()) { + if (auto* access = src->Sem().Get(expr->object)) { + if (decomposed.Contains(access->Member()->Declaration())) { auto* obj = ctx.CloneWithoutTransform(expr->object); auto* idx = ctx.Clone(expr->index); - return ctx.dst->IndexAccessor(obj, idx); + return b.IndexAccessor(obj, idx); } } return nullptr; @@ -154,39 +128,36 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co // ssbo.mat = mat_to_arr(m) std::unordered_map mat_to_arr; ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { - if (auto* access = ctx.src->Sem().Get(stmt->lhs)) { - auto it = decomposed.find(access->Member()->Declaration()); - if (it == decomposed.end()) { - return nullptr; + if (auto* access = src->Sem().Get(stmt->lhs)) { + if (auto* info = decomposed.Find(access->Member()->Declaration())) { + auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { + auto name = + b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" + + std::to_string(info->matrix->rows()) + "_stride_" + + std::to_string(info->stride) + "_to_arr"); + + auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; + auto array = [&] { return info->array(ctx.dst); }; + + auto mat = b.Sym("m"); + utils::Vector columns; + for (uint32_t i = 0; i < static_cast(info->matrix->columns()); i++) { + columns.Push(b.IndexAccessor(mat, u32(i))); + } + b.Func(name, + utils::Vector{ + b.Param(mat, matrix()), + }, + array(), + utils::Vector{ + b.Return(b.Construct(array(), columns)), + }); + return name; + }); + auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); + auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs)); + return b.Assign(lhs, rhs); } - MatrixInfo info = it->second; - auto fn = utils::GetOrCreate(mat_to_arr, info, [&] { - auto name = - ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" + - std::to_string(info.matrix->rows()) + "_stride_" + - std::to_string(info.stride) + "_to_arr"); - - auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; - auto array = [&] { return info.array(ctx.dst); }; - - auto mat = ctx.dst->Sym("m"); - utils::Vector columns; - for (uint32_t i = 0; i < static_cast(info.matrix->columns()); i++) { - columns.Push(ctx.dst->IndexAccessor(mat, u32(i))); - } - ctx.dst->Func(name, - utils::Vector{ - ctx.dst->Param(mat, matrix()), - }, - array(), - utils::Vector{ - ctx.dst->Return(ctx.dst->Construct(array(), columns)), - }); - return name; - }); - auto* lhs = ctx.CloneWithoutTransform(stmt->lhs); - auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs)); - return ctx.dst->Assign(lhs, rhs); } return nullptr; }); @@ -196,41 +167,40 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co // m = arr_to_mat(ssbo.mat) std::unordered_map arr_to_mat; ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { - if (auto* access = ctx.src->Sem().Get(expr)) { - auto it = decomposed.find(access->Member()->Declaration()); - if (it == decomposed.end()) { - return nullptr; + if (auto* access = src->Sem().Get(expr)) { + if (auto* info = decomposed.Find(access->Member()->Declaration())) { + auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { + auto name = + b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) + + "x" + std::to_string(info->matrix->rows()) + "_stride_" + + std::to_string(info->stride)); + + auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); }; + auto array = [&] { return info->array(ctx.dst); }; + + auto arr = b.Sym("arr"); + utils::Vector columns; + for (uint32_t i = 0; i < static_cast(info->matrix->columns()); i++) { + columns.Push(b.IndexAccessor(arr, u32(i))); + } + b.Func(name, + utils::Vector{ + b.Param(arr, array()), + }, + matrix(), + utils::Vector{ + b.Return(b.Construct(matrix(), columns)), + }); + return name; + }); + return b.Call(fn, ctx.CloneWithoutTransform(expr)); } - MatrixInfo info = it->second; - auto fn = utils::GetOrCreate(arr_to_mat, info, [&] { - auto name = ctx.dst->Symbols().New( - "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" + - std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride)); - - auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; - auto array = [&] { return info.array(ctx.dst); }; - - auto arr = ctx.dst->Sym("arr"); - utils::Vector columns; - for (uint32_t i = 0; i < static_cast(info.matrix->columns()); i++) { - columns.Push(ctx.dst->IndexAccessor(arr, u32(i))); - } - ctx.dst->Func(name, - utils::Vector{ - ctx.dst->Param(arr, array()), - }, - matrix(), - utils::Vector{ - ctx.dst->Return(ctx.dst->Construct(matrix(), columns)), - }); - return name; - }); - return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr)); } return nullptr; }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/decompose_strided_matrix.h b/src/tint/transform/decompose_strided_matrix.h index 40e9c3e237..947dfc653f 100644 --- a/src/tint/transform/decompose_strided_matrix.h +++ b/src/tint/transform/decompose_strided_matrix.h @@ -35,19 +35,10 @@ class DecomposeStridedMatrix final : public CastableSem().Module()->Extensions().Contains( - ast::Extension::kChromiumDisableUniformityAnalysis); -} +Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (src->Sem().Module()->Extensions().Contains( + ast::Extension::kChromiumDisableUniformityAnalysis)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis); -void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/disable_uniformity_analysis.h b/src/tint/transform/disable_uniformity_analysis.h index 3c9fb53743..a9922af989 100644 --- a/src/tint/transform/disable_uniformity_analysis.h +++ b/src/tint/transform/disable_uniformity_analysis.h @@ -27,19 +27,10 @@ class DisableUniformityAnalysis final : public CastableASTNodes().Objects()) { if (node->IsAnyOf()) { return true; @@ -44,21 +42,10 @@ bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) return false; } -namespace { +} // namespace -/// Internal class used to collect statement expansions during the transform. -class State { - private: - /// The clone context. - CloneContext& ctx; - - /// The program builder. - ProgramBuilder& b; - - /// The HoistToDeclBefore helper instance. - HoistToDeclBefore hoist_to_decl_before; - - public: +/// PIMPL state for the transform +struct ExpandCompoundAssignment::State { /// Constructor /// @param context the clone context explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {} @@ -158,15 +145,32 @@ class State { ctx.Replace(stmt, b.Assign(new_lhs(), value)); } - /// Finalize the transformation and clone the module. - void Finalize() { ctx.Clone(); } + private: + /// The clone context. + CloneContext& ctx; + + /// The program builder. + ProgramBuilder& b; + + /// The HoistToDeclBefore helper instance. + HoistToDeclBefore hoist_to_decl_before; }; -} // namespace +ExpandCompoundAssignment::ExpandCompoundAssignment() = default; -void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; + +Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; State state(ctx); - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { if (auto* assign = node->As()) { state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op); } else if (auto* inc_dec = node->As()) { @@ -175,7 +179,9 @@ void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op); } } - state.Finalize(); + + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h index 1081df7b32..6b299c5216 100644 --- a/src/tint/transform/expand_compound_assignment.h +++ b/src/tint/transform/expand_compound_assignment.h @@ -45,19 +45,13 @@ class ExpandCompoundAssignment final : public CastableAST().Functions()) { + if (fn->PipelineStage() == ast::PipelineStage::kVertex) { + return true; + } + } + return false; +} + } // namespace FirstIndexOffset::BindingPoint::BindingPoint() = default; @@ -49,16 +58,16 @@ FirstIndexOffset::Data::~Data() = default; FirstIndexOffset::FirstIndexOffset() = default; FirstIndexOffset::~FirstIndexOffset() = default; -bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const { - for (auto* fn : program->AST().Functions()) { - if (fn->PipelineStage() == ast::PipelineStage::kVertex) { - return true; - } +Transform::ApplyResult FirstIndexOffset::Apply(const Program* src, + const DataMap& inputs, + DataMap& outputs) const { + if (!ShouldRun(src)) { + return SkipTransform; } - return false; -} -void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + // Get the uniform buffer binding point uint32_t ub_binding = binding_; uint32_t ub_group = group_; @@ -115,17 +124,17 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou if (has_vertex_or_instance_index) { // Add uniform buffer members and calculate byte offsets utils::Vector members; - members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32())); - members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32())); - auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members)); + members.Push(b.Member(kFirstVertexName, b.ty.u32())); + members.Push(b.Member(kFirstInstanceName, b.ty.u32())); + auto* struct_ = b.Structure(b.Sym(), std::move(members)); // Create a global to hold the uniform buffer - Symbol buffer_name = ctx.dst->Sym(); - ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform, - utils::Vector{ - ctx.dst->Binding(AInt(ub_binding)), - ctx.dst->Group(AInt(ub_group)), - }); + Symbol buffer_name = b.Sym(); + b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform, + utils::Vector{ + b.Binding(AInt(ub_binding)), + b.Group(AInt(ub_group)), + }); // Fix up all references to the builtins with the offsets ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { @@ -150,9 +159,10 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou }); } - ctx.Clone(); - outputs.Add(has_vertex_or_instance_index); + + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/first_index_offset.h b/src/tint/transform/first_index_offset.h index 04758cde8d..f84d8112da 100644 --- a/src/tint/transform/first_index_offset.h +++ b/src/tint/transform/first_index_offset.h @@ -103,19 +103,10 @@ class FirstIndexOffset final : public Castable { /// Destructor ~FirstIndexOffset() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: uint32_t binding_ = 0; diff --git a/src/tint/transform/for_loop_to_loop.cc b/src/tint/transform/for_loop_to_loop.cc index e585790e40..63ccb12975 100644 --- a/src/tint/transform/for_loop_to_loop.cc +++ b/src/tint/transform/for_loop_to_loop.cc @@ -14,17 +14,17 @@ #include "src/tint/transform/for_loop_to_loop.h" +#include + #include "src/tint/ast/break_statement.h" #include "src/tint/program_builder.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop); namespace tint::transform { -ForLoopToLoop::ForLoopToLoop() = default; +namespace { -ForLoopToLoop::~ForLoopToLoop() = default; - -bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* node : program->ASTNodes().Objects()) { if (node->Is()) { return true; @@ -33,19 +33,31 @@ bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { return false; } -void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +} // namespace + +ForLoopToLoop::ForLoopToLoop() = default; + +ForLoopToLoop::~ForLoopToLoop() = default; + +Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { utils::Vector stmts; if (auto* cond = for_loop->condition) { // !condition - auto* not_cond = - ctx.dst->create(ast::UnaryOp::kNot, ctx.Clone(cond)); + auto* not_cond = b.Not(ctx.Clone(cond)); // { break; } - auto* break_body = ctx.dst->Block(ctx.dst->create()); + auto* break_body = b.Block(b.Break()); // if (!condition) { break; } - stmts.Push(ctx.dst->If(not_cond, break_body)); + stmts.Push(b.If(not_cond, break_body)); } for (auto* stmt : for_loop->body->statements) { stmts.Push(ctx.Clone(stmt)); @@ -53,20 +65,21 @@ void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { const ast::BlockStatement* continuing = nullptr; if (auto* cont = for_loop->continuing) { - continuing = ctx.dst->Block(ctx.Clone(cont)); + continuing = b.Block(ctx.Clone(cont)); } - auto* body = ctx.dst->Block(stmts); - auto* loop = ctx.dst->create(body, continuing); + auto* body = b.Block(stmts); + auto* loop = b.Loop(body, continuing); if (auto* init = for_loop->initializer) { - return ctx.dst->Block(ctx.Clone(init), loop); + return b.Block(ctx.Clone(init), loop); } return loop; }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/for_loop_to_loop.h b/src/tint/transform/for_loop_to_loop.h index 5ab690a567..fe3db97482 100644 --- a/src/tint/transform/for_loop_to_loop.h +++ b/src/tint/transform/for_loop_to_loop.h @@ -29,19 +29,10 @@ class ForLoopToLoop final : public Castable { /// Destructor ~ForLoopToLoop() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/localize_struct_array_assignment.cc b/src/tint/transform/localize_struct_array_assignment.cc index 80773931d4..bfe8865ffa 100644 --- a/src/tint/transform/localize_struct_array_assignment.cc +++ b/src/tint/transform/localize_struct_array_assignment.cc @@ -32,70 +32,15 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment); namespace tint::transform { -/// Private implementation of LocalizeStructArrayAssignment transform -class LocalizeStructArrayAssignment::State { - private: - CloneContext& ctx; - ProgramBuilder& b; - - /// Returns true if `expr` contains an index accessor expression to a - /// structure member of array type. - bool ContainsStructArrayIndex(const ast::Expression* expr) { - bool result = false; - ast::TraverseExpressions( - expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { - // Indexing using a runtime value? - auto* idx_sem = ctx.src->Sem().Get(ia->index); - if (!idx_sem->ConstantValue()) { - // Indexing a member access expr? - if (auto* ma = ia->object->As()) { - // That accesses an array? - if (ctx.src->TypeOf(ma)->UnwrapRef()->Is()) { - result = true; - return ast::TraverseAction::Stop; - } - } - } - return ast::TraverseAction::Descend; - }); - - return result; - } - - // Returns the type and address space of the originating variable of the lhs - // of the assignment statement. - // See https://www.w3.org/TR/WGSL/#originating-variable-section - std::pair GetOriginatingTypeAndAddressSpace( - const ast::AssignmentStatement* assign_stmt) { - auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable(); - if (!source_var) { - TINT_ICE(Transform, b.Diagnostics()) - << "Unable to determine originating variable for lhs of assignment " - "statement"; - return {}; - } - - auto* type = source_var->Type(); - if (auto* ref = type->As()) { - return {ref->StoreType(), ref->AddressSpace()}; - } else if (auto* ptr = type->As()) { - return {ptr->StoreType(), ptr->AddressSpace()}; - } - - TINT_ICE(Transform, b.Diagnostics()) - << "Expecting to find variable of type pointer or reference on lhs " - "of assignment statement"; - return {}; - } - - public: +/// PIMPL state for the transform +struct LocalizeStructArrayAssignment::State { /// Constructor - /// @param ctx_in the CloneContext primed with the input program and - /// ProgramBuilder - explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Runs the transform - void Run() { + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { struct Shared { bool process_nested_nodes = false; utils::Vector insert_before_stmts; @@ -189,6 +134,65 @@ class LocalizeStructArrayAssignment::State { }); ctx.Clone(); + return Program(std::move(b)); + } + + private: + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; + /// The clone context + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; + + /// Returns true if `expr` contains an index accessor expression to a + /// structure member of array type. + bool ContainsStructArrayIndex(const ast::Expression* expr) { + bool result = false; + ast::TraverseExpressions( + expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) { + // Indexing using a runtime value? + auto* idx_sem = src->Sem().Get(ia->index); + if (!idx_sem->ConstantValue()) { + // Indexing a member access expr? + if (auto* ma = ia->object->As()) { + // That accesses an array? + if (src->TypeOf(ma)->UnwrapRef()->Is()) { + result = true; + return ast::TraverseAction::Stop; + } + } + } + return ast::TraverseAction::Descend; + }); + + return result; + } + + // Returns the type and address space of the originating variable of the lhs + // of the assignment statement. + // See https://www.w3.org/TR/WGSL/#originating-variable-section + std::pair GetOriginatingTypeAndAddressSpace( + const ast::AssignmentStatement* assign_stmt) { + auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable(); + if (!source_var) { + TINT_ICE(Transform, b.Diagnostics()) + << "Unable to determine originating variable for lhs of assignment " + "statement"; + return {}; + } + + auto* type = source_var->Type(); + if (auto* ref = type->As()) { + return {ref->StoreType(), ref->AddressSpace()}; + } else if (auto* ptr = type->As()) { + return {ptr->StoreType(), ptr->AddressSpace()}; + } + + TINT_ICE(Transform, b.Diagnostics()) + << "Expecting to find variable of type pointer or reference on lhs " + "of assignment statement"; + return {}; } }; @@ -196,9 +200,10 @@ LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default; LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default; -void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State state(ctx); - state.Run(); +Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src, + const DataMap&, + DataMap&) const { + return State{src}.Run(); } } // namespace tint::transform diff --git a/src/tint/transform/localize_struct_array_assignment.h b/src/tint/transform/localize_struct_array_assignment.h index 130f8cc107..169e33ca7a 100644 --- a/src/tint/transform/localize_struct_array_assignment.h +++ b/src/tint/transform/localize_struct_array_assignment.h @@ -36,17 +36,13 @@ class LocalizeStructArrayAssignment final /// Destructor ~LocalizeStructArrayAssignment() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: - class State; + struct State; }; } // namespace tint::transform diff --git a/src/tint/transform/manager.cc b/src/tint/transform/manager.cc index 4e83320b41..79603c80e2 100644 --- a/src/tint/transform/manager.cc +++ b/src/tint/transform/manager.cc @@ -31,9 +31,9 @@ namespace tint::transform { Manager::Manager() = default; Manager::~Manager() = default; -Output Manager::Run(const Program* program, const DataMap& data) const { - const Program* in = program; - +Transform::ApplyResult Manager::Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const { #if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM auto print_program = [&](const char* msg, const Transform* transform) { auto wgsl = Program::printer(in); @@ -46,34 +46,30 @@ Output Manager::Run(const Program* program, const DataMap& data) const { }; #endif - Output out; + std::optional output; + for (const auto& transform : transforms_) { - if (!transform->ShouldRun(in, data)) { - TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name - << std::endl); - continue; - } TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); - auto res = transform->Run(in, data); - out.program = std::move(res.program); - out.data.Add(std::move(res.data)); - in = &out.program; - if (!in->IsValid()) { - TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); - return out; - } + if (auto result = transform->Apply(program, inputs, outputs)) { + output.emplace(std::move(result.value())); + program = &output.value(); - if (transform == transforms_.back()) { - TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); + if (!program->IsValid()) { + TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); + break; + } + + if (transform == transforms_.back()) { + TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); + } + } else { + TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name + << std::endl); } } - if (program == in) { - out.program = program->Clone(); - } - - return out; + return output; } } // namespace tint::transform diff --git a/src/tint/transform/manager.h b/src/tint/transform/manager.h index 9d4049f719..64ca847034 100644 --- a/src/tint/transform/manager.h +++ b/src/tint/transform/manager.h @@ -47,11 +47,10 @@ class Manager final : public Castable { transforms_.emplace_back(std::make_unique(std::forward(args)...)); } - /// Runs the transforms on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformed program and diagnostics - Output Run(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: std::vector> transforms_; diff --git a/src/tint/transform/merge_return.cc b/src/tint/transform/merge_return.cc index aec6b6d764..2b45b73c92 100644 --- a/src/tint/transform/merge_return.cc +++ b/src/tint/transform/merge_return.cc @@ -65,15 +65,6 @@ MergeReturn::MergeReturn() = default; MergeReturn::~MergeReturn() = default; -bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const { - for (auto* func : program->AST().Functions()) { - if (NeedsTransform(program, func)) { - return true; - } - } - return false; -} - namespace { /// Internal class used to during the transform. @@ -223,7 +214,12 @@ class State { } // namespace -void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + bool made_changes = false; + for (auto* func : ctx.src->AST().Functions()) { if (!NeedsTransform(ctx.src, func)) { continue; @@ -231,9 +227,15 @@ void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const { State state(ctx, func); state.ProcessStatement(func->body); + made_changes = true; + } + + if (!made_changes) { + return SkipTransform; } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/merge_return.h b/src/tint/transform/merge_return.h index 1334a5cdc5..f6db5c2ec9 100644 --- a/src/tint/transform/merge_return.h +++ b/src/tint/transform/merge_return.h @@ -27,19 +27,10 @@ class MergeReturn final : public Castable { /// Destructor ~MergeReturn() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc index f9c11e50ad..16a622e891 100644 --- a/src/tint/transform/module_scope_var_to_entry_point_param.cc +++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc @@ -38,6 +38,15 @@ using WorkgroupParameterMemberList = utils::Vector; // The name of the struct member for arrays that are wrapped in structures. const char* kWrappedArrayMemberName = "arr"; +bool ShouldRun(const Program* program) { + for (auto* decl : program->AST().GlobalDeclarations()) { + if (decl->Is()) { + return true; + } + } + return false; +} + // Returns `true` if `type` is or contains a matrix type. bool ContainsMatrix(const sem::Type* type) { type = type->UnwrapRef(); @@ -56,7 +65,7 @@ bool ContainsMatrix(const sem::Type* type) { } } // namespace -/// State holds the current transform state. +/// PIMPL state for the transform struct ModuleScopeVarToEntryPointParam::State { /// The clone context. CloneContext& ctx; @@ -501,19 +510,20 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default; ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; -bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const { - for (auto* decl : program->AST().GlobalDeclarations()) { - if (decl->Is()) { - return true; - } +Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; } - return false; -} -void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; State state{ctx}; state.Process(); + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.h b/src/tint/transform/module_scope_var_to_entry_point_param.h index 75bdaf3682..377151f854 100644 --- a/src/tint/transform/module_scope_var_to_entry_point_param.h +++ b/src/tint/transform/module_scope_var_to_entry_point_param.h @@ -69,20 +69,12 @@ class ModuleScopeVarToEntryPointParam final /// Destructor ~ModuleScopeVarToEntryPointParam() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; + private: struct State; }; diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc index 002b8580db..c3ebf4a124 100644 --- a/src/tint/transform/multiplanar_external_texture.cc +++ b/src/tint/transform/multiplanar_external_texture.cc @@ -31,6 +31,17 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::transform { namespace { +bool ShouldRun(const Program* program) { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* ty = node->As()) { + if (program->Sem().Get(ty)) { + return true; + } + } + } + return false; +} + /// This struct stores symbols for new bindings created as a result of transforming a /// texture_external instance. struct NewBindingSymbols { @@ -40,7 +51,7 @@ struct NewBindingSymbols { }; } // namespace -/// State holds the current transform state +/// PIMPL state for the transform struct MultiplanarExternalTexture::State { /// The clone context. CloneContext& ctx; @@ -537,30 +548,26 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default; MultiplanarExternalTexture::MultiplanarExternalTexture() = default; MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; -bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const { - for (auto* node : program->ASTNodes().Objects()) { - if (auto* ty = node->As()) { - if (program->Sem().Get(ty)) { - return true; - } - } - } - return false; -} - // Within this transform, an instance of a texture_external binding is unpacked into two // texture_2d bindings representing two possible planes of a single texture and a uniform // buffer binding representing a struct of parameters. Calls to texture builtins that contain a // texture_external parameter will be transformed into a newly generated version of the function, // which can perform the desired operation on a single RGBA plane or on separate Y and UV planes. -void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { auto* new_binding_points = inputs.Get(); + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; if (!new_binding_points) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "missing new binding point data for " + std::string(TypeInfo().name)); - return; + b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " + + std::string(TypeInfo().name)); + return Program(std::move(b)); } State state(ctx, new_binding_points); @@ -568,6 +575,7 @@ void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, D state.Process(); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/multiplanar_external_texture.h b/src/tint/transform/multiplanar_external_texture.h index a10fed4c09..695e38c6b4 100644 --- a/src/tint/transform/multiplanar_external_texture.h +++ b/src/tint/transform/multiplanar_external_texture.h @@ -80,21 +80,13 @@ class MultiplanarExternalTexture final : public Castable(src)); + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); + + EXPECT_FALSE(ShouldRun(src, data)); } TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { @@ -31,14 +35,22 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { type ET = texture_external; )"; - EXPECT_TRUE(ShouldRun(src)); + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); + + EXPECT_TRUE(ShouldRun(src, data)); } TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) { auto* src = R"( @group(0) @binding(0) var ext_tex : texture_external; )"; - EXPECT_TRUE(ShouldRun(src)); + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); + + EXPECT_TRUE(ShouldRun(src, data)); } TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { @@ -46,7 +58,11 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { fn f(ext_tex : texture_external) {} )"; - EXPECT_TRUE(ShouldRun(src)); + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); + + EXPECT_TRUE(ShouldRun(src, data)); } // Running the transform without passing in data for the new bindings should result in an error. diff --git a/src/tint/transform/num_workgroups_from_uniform.cc b/src/tint/transform/num_workgroups_from_uniform.cc index 2122f072f3..e6681ca006 100644 --- a/src/tint/transform/num_workgroups_from_uniform.cc +++ b/src/tint/transform/num_workgroups_from_uniform.cc @@ -29,6 +29,18 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config); namespace tint::transform { namespace { + +bool ShouldRun(const Program* program) { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* attr = node->As()) { + if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { + return true; + } + } + } + return false; +} + /// Accessor describes the identifiers used in a member accessor that is being /// used to retrieve the num_workgroups builtin from a parameter. struct Accessor { @@ -44,41 +56,40 @@ struct Accessor { size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); } }; }; + } // namespace NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; -bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const { - for (auto* node : program->ASTNodes().Objects()) { - if (auto* attr = node->As()) { - if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) { - return true; - } - } - } - return false; -} +Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; -void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { auto* cfg = inputs.Get(); if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - return; + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return Program(std::move(b)); + } + + if (!ShouldRun(src)) { + return SkipTransform; } const char* kNumWorkgroupsMemberName = "num_workgroups"; // Find all entry point parameters that declare the num_workgroups builtin. std::unordered_set to_replace; - for (auto* func : ctx.src->AST().Functions()) { + for (auto* func : src->AST().Functions()) { // num_workgroups is only valid for compute stages. if (func->PipelineStage() != ast::PipelineStage::kCompute) { continue; } - for (auto* param : ctx.src->Sem().Get(func)->Parameters()) { + for (auto* param : src->Sem().Get(func)->Parameters()) { // Because the CanonicalizeEntryPointIO transform has been run, builtins // will only appear as struct members. auto* str = param->Type()->As(); @@ -108,7 +119,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat // If this is the only member, remove the struct and parameter too. if (str->Members().size() == 1) { ctx.Remove(func->params, param->Declaration()); - ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration()); + ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration()); } } } @@ -119,11 +130,10 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat const ast::Variable* num_workgroups_ubo = nullptr; auto get_ubo = [&]() { if (!num_workgroups_ubo) { - auto* num_workgroups_struct = ctx.dst->Structure( - ctx.dst->Sym(), - utils::Vector{ - ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())), - }); + auto* num_workgroups_struct = + b.Structure(b.Sym(), utils::Vector{ + b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())), + }); uint32_t group, binding; if (cfg->ubo_binding.has_value()) { @@ -135,9 +145,9 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat // plus 1, or group 0 if no resource bound. group = 0; - for (auto* global : ctx.src->AST().GlobalVariables()) { + for (auto* global : src->AST().GlobalVariables()) { if (global->HasBindingPoint()) { - auto* global_sem = ctx.src->Sem().Get(global); + auto* global_sem = src->Sem().Get(global); auto binding_point = global_sem->BindingPoint(); if (binding_point.group >= group) { group = binding_point.group + 1; @@ -148,16 +158,16 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat binding = 0; } - num_workgroups_ubo = ctx.dst->GlobalVar( - ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, - ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding))); + num_workgroups_ubo = + b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform, + b.Group(AInt(group)), b.Binding(AInt(binding))); } return num_workgroups_ubo; }; // Now replace all the places where the builtins are accessed with the value // loaded from the uniform buffer. - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { auto* accessor = node->As(); if (!accessor) { continue; @@ -168,12 +178,12 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat } if (to_replace.count({ident->symbol, accessor->member->symbol})) { - ctx.Replace(accessor, - ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); + ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName)); } } ctx.Clone(); + return Program(std::move(b)); } NumWorkgroupsFromUniform::Config::Config(std::optional ubo_bp) diff --git a/src/tint/transform/num_workgroups_from_uniform.h b/src/tint/transform/num_workgroups_from_uniform.h index 292c823bc4..25308f2294 100644 --- a/src/tint/transform/num_workgroups_from_uniform.h +++ b/src/tint/transform/num_workgroups_from_uniform.h @@ -72,19 +72,10 @@ class NumWorkgroupsFromUniform final : public Castable ubo_binding; }; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/num_workgroups_from_uniform_test.cc b/src/tint/transform/num_workgroups_from_uniform_test.cc index 093081c6e7..435ab33215 100644 --- a/src/tint/transform/num_workgroups_from_uniform_test.cc +++ b/src/tint/transform/num_workgroups_from_uniform_test.cc @@ -28,7 +28,9 @@ using NumWorkgroupsFromUniformTest = TransformTest; TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) { auto* src = R"()"; - EXPECT_FALSE(ShouldRun(src)); + DataMap data; + data.Add(sem::BindingPoint{0, 30u}); + EXPECT_FALSE(ShouldRun(src, data)); } TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) { @@ -38,7 +40,9 @@ fn main(@builtin(num_workgroups) num_wgs : vec3) { } )"; - EXPECT_TRUE(ShouldRun(src)); + DataMap data; + data.Add(sem::BindingPoint{0, 30u}); + EXPECT_TRUE(ShouldRun(src, data)); } TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { @@ -55,7 +59,6 @@ fn main(@builtin(num_workgroups) num_wgs : vec3) { DataMap data; data.Add(CanonicalizeEntryPointIO::ShaderStyle::kHlsl); auto got = Run(src, data); - EXPECT_EQ(expect, str(got)); } diff --git a/src/tint/transform/packed_vec3.cc b/src/tint/transform/packed_vec3.cc index dde5aca374..e947a538db 100644 --- a/src/tint/transform/packed_vec3.cc +++ b/src/tint/transform/packed_vec3.cc @@ -33,14 +33,15 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::transform { -/// The PIMPL state for the PackedVec3 transform +/// PIMPL state for the transform struct PackedVec3::State { /// Constructor - /// @param c the CloneContext - explicit State(CloneContext& c) : ctx(c) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Runs the transform - void Run() { + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { // Packed vec3 struct members utils::Hashset members; @@ -72,6 +73,10 @@ struct PackedVec3::State { } } + if (members.IsEmpty()) { + return SkipTransform; + } + // Walk the nodes, starting with the most deeply nested, finding all the AST expressions // that load a whole packed vector (not a scalar / swizzle of the vector). utils::Hashset refs; @@ -137,36 +142,20 @@ struct PackedVec3::State { } ctx.Clone(); - } - - /// @returns true if this transform should be run for the given program - /// @param program the program to inspect - static bool ShouldRun(const Program* program) { - for (auto* decl : program->AST().GlobalDeclarations()) { - if (auto* str = program->Sem().Get(decl)) { - if (str->IsHostShareable()) { - for (auto* member : str->Members()) { - if (auto* vec = member->Type()->As()) { - if (vec->Width() == 3) { - return true; - } - } - } - } - } - } - return false; + return Program(std::move(b)); } private: + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; /// Alias to the semantic info in ctx.src const sem::Info& sem = ctx.src->Sem(); /// Alias to the symbols in ctx.src const SymbolTable& sym = ctx.src->Symbols(); - /// Alias to the ctx.dst program builder - ProgramBuilder& b = *ctx.dst; }; PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {} @@ -183,12 +172,8 @@ std::string PackedVec3::Attribute::InternalName() const { PackedVec3::PackedVec3() = default; PackedVec3::~PackedVec3() = default; -bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const { - return State::ShouldRun(program); -} - -void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State(ctx).Run(); +Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const { + return State{src}.Run(); } } // namespace tint::transform diff --git a/src/tint/transform/packed_vec3.h b/src/tint/transform/packed_vec3.h index 9d899cbf5c..0d304fa456 100644 --- a/src/tint/transform/packed_vec3.h +++ b/src/tint/transform/packed_vec3.h @@ -56,21 +56,13 @@ class PackedVec3 final : public Castable { /// Destructor ~PackedVec3() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: struct State; - - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc index 10b0565cf0..4ceb39dc82 100644 --- a/src/tint/transform/pad_structs.cc +++ b/src/tint/transform/pad_structs.cc @@ -50,8 +50,10 @@ PadStructs::PadStructs() = default; PadStructs::~PadStructs() = default; -void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - auto& sem = ctx.src->Sem(); +Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + auto& sem = src->Sem(); std::unordered_map replaced_structs; utils::Hashset padding_members; @@ -65,7 +67,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { bool has_runtime_sized_array = false; utils::Vector new_members; for (auto* mem : str->Members()) { - auto name = ctx.src->Symbols().NameFor(mem->Name()); + auto name = src->Symbols().NameFor(mem->Name()); if (offset < mem->Offset()) { CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset); @@ -75,7 +77,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { auto* ty = mem->Type(); const ast::Type* type = CreateASTTypeFor(ctx, ty); - new_members.Push(ctx.dst->Member(name, type)); + new_members.Push(b.Member(name, type)); uint32_t size = ty->Size(); if (ty->Is() && str->UsedAs(ast::AddressSpace::kUniform)) { @@ -97,8 +99,8 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { if (offset < struct_size && !has_runtime_sized_array) { CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset); } - auto* new_struct = ctx.dst->create(ctx.Clone(ast_str->name), - std::move(new_members), utils::Empty); + auto* new_struct = + b.create(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty); replaced_structs[ast_str] = new_struct; return new_struct; }); @@ -131,16 +133,17 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { auto* arg = ast_call->args.begin(); for (auto* member : new_struct->members) { if (padding_members.Contains(member)) { - new_args.Push(ctx.dst->Expr(0_u)); + new_args.Push(b.Expr(0_u)); } else { new_args.Push(ctx.Clone(*arg)); arg++; } } - return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args); + return b.Construct(CreateASTTypeFor(ctx, str), new_args); }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/pad_structs.h b/src/tint/transform/pad_structs.h index 55fec749e4..e9996d4e92 100644 --- a/src/tint/transform/pad_structs.h +++ b/src/tint/transform/pad_structs.h @@ -30,14 +30,10 @@ class PadStructs final : public Castable { /// Destructor ~PadStructs() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/promote_initializers_to_let.cc b/src/tint/transform/promote_initializers_to_let.cc index 315a4cef41..9e02c45b69 100644 --- a/src/tint/transform/promote_initializers_to_let.cc +++ b/src/tint/transform/promote_initializers_to_let.cc @@ -13,6 +13,9 @@ // limitations under the License. #include "src/tint/transform/promote_initializers_to_let.h" + +#include + #include "src/tint/program_builder.h" #include "src/tint/sem/call.h" #include "src/tint/sem/statement.h" @@ -27,9 +30,16 @@ PromoteInitializersToLet::PromoteInitializersToLet() = default; PromoteInitializersToLet::~PromoteInitializersToLet() = default; -void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + HoistToDeclBefore hoist_to_decl_before(ctx); + bool any_promoted = false; + // Hoists array and structure initializers to a constant variable, declared // just before the statement of usage. auto promote = [&](const sem::Expression* expr) { @@ -59,14 +69,15 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) return true; } + any_promoted = true; return hoist_to_decl_before.Add(expr, expr->Declaration(), true); }; - for (auto* node : ctx.src->ASTNodes().Objects()) { + for (auto* node : src->ASTNodes().Objects()) { bool ok = Switch( node, // [&](const ast::CallExpression* expr) { - if (auto* sem = ctx.src->Sem().Get(expr)) { + if (auto* sem = src->Sem().Get(expr)) { auto* ctor = sem->UnwrapMaterialize()->As(); if (ctor->Target()->Is()) { return promote(sem); @@ -75,7 +86,7 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) return true; }, [&](const ast::IdentifierExpression* expr) { - if (auto* sem = ctx.src->Sem().Get(expr)) { + if (auto* sem = src->Sem().Get(expr)) { if (auto* user = sem->UnwrapMaterialize()->As()) { // Identifier resolves to a variable if (auto* stmt = user->Stmt()) { @@ -96,13 +107,17 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) return true; }, [&](Default) { return true; }); - if (!ok) { - return; + return Program(std::move(b)); } } + if (!any_promoted) { + return SkipTransform; + } + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/promote_initializers_to_let.h b/src/tint/transform/promote_initializers_to_let.h index 78793c70eb..b1bb291a1f 100644 --- a/src/tint/transform/promote_initializers_to_let.h +++ b/src/tint/transform/promote_initializers_to_let.h @@ -33,14 +33,10 @@ class PromoteInitializersToLet final : public Castable { - class State; - void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; + ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; }; -class SimplifySideEffectStatements::State : public StateBase { - HoistToDeclBefore hoist_to_decl_before; +Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; - public: - explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {} + bool made_changes = false; - void Run() { - for (auto* node : ctx.src->ASTNodes().Objects()) { - if (auto* expr = node->As()) { - auto* sem_expr = sem.Get(expr); - if (!sem_expr || !sem_expr->HasSideEffects()) { - continue; - } - - hoist_to_decl_before.Prepare(sem_expr); + HoistToDeclBefore hoist_to_decl_before(ctx); + for (auto* node : ctx.src->ASTNodes().Objects()) { + if (auto* expr = node->As()) { + auto* sem_expr = src->Sem().Get(expr); + if (!sem_expr || !sem_expr->HasSideEffects()) { + continue; } - } - ctx.Clone(); - } -}; -void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State state(ctx); - state.Run(); + hoist_to_decl_before.Prepare(sem_expr); + made_changes = true; + } + } + + if (!made_changes) { + return SkipTransform; + } + + ctx.Clone(); + return Program(std::move(b)); } // Decomposes side-effecting expressions to ensure order of evaluation. This @@ -89,7 +91,7 @@ void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMa struct DecomposeSideEffects : Castable { class CollectHoistsState; class DecomposeState; - void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override; + ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override; }; // CollectHoistsState traverses the AST top-down, identifying which expressions @@ -667,12 +669,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase { } return nullptr; }); - - ctx.Clone(); } }; -void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + // First collect side-effecting expressions to hoist CollectHoistsState collect_hoists_state{ctx}; auto to_hoist = collect_hoists_state.Run(); @@ -680,6 +685,9 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons // Now decompose these expressions DecomposeState decompose_state{ctx, std::move(to_hoist)}; decompose_state.Run(); + + ctx.Clone(); + return Program(std::move(b)); } } // namespace @@ -687,13 +695,13 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default; PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default; -Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const { +Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src, + const DataMap& inputs, + DataMap& outputs) const { transform::Manager manager; manager.Add(); manager.Add(); - - auto output = manager.Run(program, data); - return output; + return manager.Apply(src, inputs, outputs); } } // namespace tint::transform diff --git a/src/tint/transform/promote_side_effects_to_decl.h b/src/tint/transform/promote_side_effects_to_decl.h index d5d1126133..99e80c65b2 100644 --- a/src/tint/transform/promote_side_effects_to_decl.h +++ b/src/tint/transform/promote_side_effects_to_decl.h @@ -31,12 +31,10 @@ class PromoteSideEffectsToDecl final : public Castable switch_to_cont_var_name; - - // If `cont` is within a switch statement within a loop, returns a pointer to - // that switch statement. - static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, - const ast::ContinueStatement* cont) { - // Find whether first parent is a switch or a loop - auto* sem_stmt = sem.Get(cont); - auto* sem_parent = sem_stmt->FindFirstParent(); - if (!sem_parent) { - return nullptr; - } - return sem_parent->Declaration()->As(); - } - - public: +/// PIMPL state for the transform +struct RemoveContinueInSwitch::State { /// Constructor - /// @param ctx_in the context - explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} - - /// Returns true if this transform should be run for the given program - static bool ShouldRun(const Program* program) { - for (auto* node : program->ASTNodes().Objects()) { - auto* stmt = node->As(); - if (!stmt) { - continue; - } - if (GetParentSwitchInLoop(program->Sem(), stmt)) { - return true; - } - } - return false; - } + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Runs the transform - void Run() { - for (auto* node : ctx.src->ASTNodes().Objects()) { + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { + bool made_changes = false; + + for (auto* node : src->ASTNodes().Objects()) { auto* cont = node->As(); if (!cont) { continue; @@ -90,6 +56,8 @@ class State { continue; } + made_changes = true; + auto cont_var_name = tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() { // Create and insert 'var tint_continue : bool = false;' before the @@ -116,22 +84,50 @@ class State { ctx.Replace(cont, new_stmt); } + if (!made_changes) { + return SkipTransform; + } + ctx.Clone(); + return Program(std::move(b)); + } + + private: + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; + /// The clone context + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; + /// Alias to src->sem + const sem::Info& sem = src->Sem(); + + // Map of switch statement to 'tint_continue' variable. + std::unordered_map switch_to_cont_var_name; + + // If `cont` is within a switch statement within a loop, returns a pointer to + // that switch statement. + static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, + const ast::ContinueStatement* cont) { + // Find whether first parent is a switch or a loop + auto* sem_stmt = sem.Get(cont); + auto* sem_parent = sem_stmt->FindFirstParent(); + if (!sem_parent) { + return nullptr; + } + return sem_parent->Declaration()->As(); } }; -} // namespace - RemoveContinueInSwitch::RemoveContinueInSwitch() = default; RemoveContinueInSwitch::~RemoveContinueInSwitch() = default; -bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const { - return State::ShouldRun(program); -} - -void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State state(ctx); - state.Run(); +Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src, + const DataMap&, + DataMap&) const { + State state(src); + return state.Run(); } } // namespace tint::transform diff --git a/src/tint/transform/remove_continue_in_switch.h b/src/tint/transform/remove_continue_in_switch.h index 9e5a4d51ad..10709069ed 100644 --- a/src/tint/transform/remove_continue_in_switch.h +++ b/src/tint/transform/remove_continue_in_switch.h @@ -31,19 +31,13 @@ class RemoveContinueInSwitch final : public CastableASTNodes().Objects()) { - if (node->Is()) { - return true; - } - if (auto* stmt = node->As()) { - if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) { - return true; - } - } - } - return false; -} +Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; -void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - auto& sem = ctx.src->Sem(); + auto& sem = src->Sem(); - std::unordered_map> sinks; + utils::Hashmap> sinks; - for (auto* node : ctx.src->ASTNodes().Objects()) { + bool made_changes = false; + for (auto* node : src->ASTNodes().Objects()) { Switch( node, [&](const ast::AssignmentStatement* stmt) { if (stmt->lhs->Is()) { + made_changes = true; + std::vector side_effects; if (!ast::TraverseExpressions( - stmt->rhs, ctx.dst->Diagnostics(), - [&](const ast::CallExpression* expr) { + stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) { // ast::CallExpression may map to a function or builtin call // (both may have side-effects), or a type initializer or // type conversion (both do not have side effects). @@ -100,8 +91,7 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { if (auto* call = side_effects[0]->As()) { // Phony assignment with single call side effect. // Replace phony assignment with call. - ctx.Replace(stmt, - [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); }); + ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); }); return; } } @@ -114,22 +104,21 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { for (auto* arg : side_effects) { sig.push_back(sem.Get(arg)->Type()->UnwrapRef()); } - auto sink = utils::GetOrCreate(sinks, sig, [&] { - auto name = ctx.dst->Symbols().New("phony_sink"); + auto sink = sinks.GetOrCreate(sig, [&] { + auto name = b.Symbols().New("phony_sink"); utils::Vector params; for (auto* ty : sig) { auto* ast_ty = CreateASTTypeFor(ctx, ty); - params.Push( - ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty)); + params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty)); } - ctx.dst->Func(name, params, ctx.dst->ty.void_(), {}); + b.Func(name, params, b.ty.void_(), {}); return name; }); utils::Vector args; for (auto* arg : side_effects) { args.Push(ctx.Clone(arg)); } - return ctx.dst->CallStmt(ctx.dst->Call(sink, args)); + return b.CallStmt(b.Call(sink, args)); }); } }, @@ -138,12 +127,18 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects. auto* sem_expr = sem.Get(stmt->expr); if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) { + made_changes = true; ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt); } }); } + if (!made_changes) { + return SkipTransform; + } + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/remove_phonies.h b/src/tint/transform/remove_phonies.h index daa181270e..99a049e3f2 100644 --- a/src/tint/transform/remove_phonies.h +++ b/src/tint/transform/remove_phonies.h @@ -33,19 +33,10 @@ class RemovePhonies final : public Castable { /// Destructor ~RemovePhonies() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/remove_unreachable_statements.cc b/src/tint/transform/remove_unreachable_statements.cc index 964d767c35..f9bf202737 100644 --- a/src/tint/transform/remove_unreachable_statements.cc +++ b/src/tint/transform/remove_unreachable_statements.cc @@ -36,27 +36,28 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default; RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; -bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const { - for (auto* node : program->ASTNodes().Objects()) { - if (auto* stmt = program->Sem().Get(node)) { +Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + bool made_changes = false; + for (auto* node : src->ASTNodes().Objects()) { + if (auto* stmt = src->Sem().Get(node)) { if (!stmt->IsReachable()) { - return true; + RemoveStatement(ctx, stmt->Declaration()); + made_changes = true; } } } - return false; -} -void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - for (auto* node : ctx.src->ASTNodes().Objects()) { - if (auto* stmt = ctx.src->Sem().Get(node)) { - if (!stmt->IsReachable()) { - RemoveStatement(ctx, stmt->Declaration()); - } - } + if (!made_changes) { + return SkipTransform; } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/remove_unreachable_statements.h b/src/tint/transform/remove_unreachable_statements.h index 7f8b9472ad..f5848f5fd8 100644 --- a/src/tint/transform/remove_unreachable_statements.h +++ b/src/tint/transform/remove_unreachable_statements.h @@ -32,19 +32,10 @@ class RemoveUnreachableStatements final : public Castable preserve; - for (auto* node : in->ASTNodes().Objects()) { + utils::Hashset preserve; + for (auto* node : src->ASTNodes().Objects()) { if (auto* member = node->As()) { - auto* sem = in->Sem().Get(member); - if (!sem) { - TINT_ICE(Transform, out.Diagnostics()) - << "MemberAccessorExpression has no semantic info"; - continue; - } + auto* sem = src->Sem().Get(member); if (sem->Is()) { - preserve.emplace(member->member); - } else if (auto* str_expr = in->Sem().Get(member->structure)) { + preserve.Add(member->member); + } else if (auto* str_expr = src->Sem().Get(member->structure)) { if (auto* ty = str_expr->Type()->UnwrapRef()->As()) { if (ty->Declaration() == nullptr) { // Builtin structure - preserve.emplace(member->member); + preserve.Add(member->member); } } } } else if (auto* call = node->As()) { - auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As(); - if (!sem) { - TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info"; - continue; - } + auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As(); if (sem->Target()->Is()) { - preserve.emplace(call->target.name); + preserve.Add(call->target.name); } } } @@ -1300,7 +1292,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { } ctx.ReplaceAll([&](Symbol sym_in) { - auto name_in = ctx.src->Symbols().NameFor(sym_in); + auto name_in = src->Symbols().NameFor(sym_in); if (preserve_unicode || text::utf8::IsASCII(name_in)) { switch (target) { case Target::kAll: @@ -1343,17 +1335,20 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const { }); ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* { - if (preserve.count(ident)) { + if (preserve.Contains(ident)) { auto sym_in = ident->symbol; - auto str = in->Symbols().NameFor(sym_in); - auto sym_out = out.Symbols().Register(str); + auto str = src->Symbols().NameFor(sym_in); + auto sym_out = b.Symbols().Register(str); return ctx.dst->create(ctx.Clone(ident->source), sym_out); } return nullptr; // Clone ident. Uses the symbol remapping above. }); - ctx.Clone(); - return Output(Program(std::move(out)), std::make_unique(std::move(remappings))); + ctx.Clone(); // Must come before the std::move() + + outputs.Add(std::move(remappings)); + + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/renamer.h b/src/tint/transform/renamer.h index 000aee9ca2..8a9f97e952 100644 --- a/src/tint/transform/renamer.h +++ b/src/tint/transform/renamer.h @@ -85,11 +85,10 @@ class Renamer final : public Castable { /// Destructor ~Renamer() override; - /// Runs the transform on `program`, returning the transformation result. - /// @param program the source program to transform - /// @param data optional extra transform-specific input data - /// @returns the transformation result - Output Run(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc index a22f84fcdf..75d63f25b5 100644 --- a/src/tint/transform/robustness.cc +++ b/src/tint/transform/robustness.cc @@ -33,36 +33,48 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::transform { -/// State holds the current transform state +/// PIMPL state for the transform struct Robustness::State { - /// The clone context - CloneContext& ctx; + /// Constructor + /// @param program the source program + /// @param omitted the omitted address spaces + State(const Program* program, std::unordered_set&& omitted) + : src(program), omitted_address_spaces(std::move(omitted)) {} - /// Set of address spacees to not apply the transform to - std::unordered_set omitted_classes; - - /// Applies the transformation state to `ctx`. - void Transform() { + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); }); ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); }); + + ctx.Clone(); + return Program(std::move(b)); } + private: + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; + /// The clone context + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; + + /// Set of address spaces to not apply the transform to + std::unordered_set omitted_address_spaces; + /// Apply bounds clamping to array, vector and matrix indexing /// @param expr the array, vector or matrix index expression /// @return the clamped replacement expression, or nullptr if `expr` should be cloned without /// changes. const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) { - auto* sem = - ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As(); + auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As(); auto* ret_type = sem->Type(); auto* ref = ret_type->As(); - if (ref && omitted_classes.count(ref->AddressSpace()) != 0) { + if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) { return nullptr; } - ProgramBuilder& b = *ctx.dst; - // idx return the cloned index expression, as a u32. auto idx = [&]() -> const ast::Expression* { auto* i = ctx.Clone(expr->index); @@ -109,8 +121,8 @@ struct Robustness::State { } else { // Note: Don't be tempted to use the array override variable as an expression // here, the name might be shadowed! - ctx.dst->Diagnostics().add_error(diag::System::Transform, - sem::Array::kErrExpectedConstantCount); + b.Diagnostics().add_error(diag::System::Transform, + sem::Array::kErrExpectedConstantCount); return nullptr; } @@ -119,7 +131,7 @@ struct Robustness::State { [&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled object type in robustness of array index: " - << ctx.src->FriendlyName(ret_type->UnwrapRef()); + << src->FriendlyName(ret_type->UnwrapRef()); return nullptr; }); @@ -127,9 +139,9 @@ struct Robustness::State { return nullptr; // Clamping not needed } - auto src = ctx.Clone(expr->source); - auto* obj = ctx.Clone(expr->object); - return b.IndexAccessor(src, obj, clamped_idx); + auto idx_src = ctx.Clone(expr->source); + auto* idx_obj = ctx.Clone(expr->object); + return b.IndexAccessor(idx_src, idx_obj, clamped_idx); } /// @param type builtin type @@ -145,15 +157,13 @@ struct Robustness::State { /// @return the clamped replacement call expression, or nullptr if `expr` /// should be cloned without changes. const ast::CallExpression* Transform(const ast::CallExpression* expr) { - auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As(); + auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As(); auto* call_target = call->Target(); auto* builtin = call_target->As(); if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) { return nullptr; // No transform, just clone. } - ProgramBuilder& b = *ctx.dst; - // Indices of the mandatory texture and coords parameters, and the optional // array and level parameters. auto& signature = builtin->Signature(); @@ -261,7 +271,7 @@ struct Robustness::State { // Clamp the level argument, if provided if (level_idx >= 0) { auto* arg = expr->args[static_cast(level_idx)]; - ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a)); + ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a)); } return nullptr; // Clone, which will use the argument replacements above. @@ -276,28 +286,27 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default; Robustness::Robustness() = default; Robustness::~Robustness() = default; -void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult Robustness::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { Config cfg; if (auto* cfg_data = inputs.Get()) { cfg = *cfg_data; } - std::unordered_set omitted_classes; - for (auto sc : cfg.omitted_classes) { + std::unordered_set omitted_address_spaces; + for (auto sc : cfg.omitted_address_spaces) { switch (sc) { case AddressSpace::kUniform: - omitted_classes.insert(ast::AddressSpace::kUniform); + omitted_address_spaces.insert(ast::AddressSpace::kUniform); break; case AddressSpace::kStorage: - omitted_classes.insert(ast::AddressSpace::kStorage); + omitted_address_spaces.insert(ast::AddressSpace::kStorage); break; } } - State state{ctx, std::move(omitted_classes)}; - - state.Transform(); - ctx.Clone(); + return State{src, std::move(omitted_address_spaces)}.Run(); } } // namespace tint::transform diff --git a/src/tint/transform/robustness.h b/src/tint/transform/robustness.h index 21a7ff9a01..14c5fe12a3 100644 --- a/src/tint/transform/robustness.h +++ b/src/tint/transform/robustness.h @@ -54,9 +54,9 @@ class Robustness final : public Castable { /// @returns this Config Config& operator=(const Config&); - /// Address spacees to omit from apply the transform to. + /// Address spaces to omit from apply the transform to. /// This allows for optimizing on hardware that provide safe accesses. - std::unordered_set omitted_classes; + std::unordered_set omitted_address_spaces; }; /// Constructor @@ -64,14 +64,10 @@ class Robustness final : public Castable { /// Destructor ~Robustness() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: struct State; diff --git a/src/tint/transform/robustness_test.cc b/src/tint/transform/robustness_test.cc index 16d958f21d..990bbde654 100644 --- a/src/tint/transform/robustness_test.cc +++ b/src/tint/transform/robustness_test.cc @@ -1274,7 +1274,7 @@ fn f() { )"; Robustness::Config cfg; - cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); + cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); DataMap data; data.Add(cfg); @@ -1325,7 +1325,7 @@ fn f() { )"; Robustness::Config cfg; - cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); + cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); DataMap data; data.Add(cfg); @@ -1376,8 +1376,8 @@ fn f() { )"; Robustness::Config cfg; - cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage); - cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform); + cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage); + cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform); DataMap data; data.Add(cfg); diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc index ea35699230..b2b99ed470 100644 --- a/src/tint/transform/simplify_pointers.cc +++ b/src/tint/transform/simplify_pointers.cc @@ -45,14 +45,18 @@ struct PointerOp { } // namespace -/// The PIMPL state for the SimplifyPointers transform +/// PIMPL state for the transform struct SimplifyPointers::State { + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; /// Constructor - /// @param context the clone context - explicit State(CloneContext& context) : ctx(context) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Traverses the expression `expr` looking for non-literal array indexing /// expressions that would affect the computed address of a pointer @@ -120,10 +124,11 @@ struct SimplifyPointers::State { } } - /// Performs the transformation - void Run() { + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { // A map of saved expressions to their saved variable name - std::unordered_map saved_vars; + utils::Hashmap saved_vars; // Register the ast::Expression transform handler. // This performs two different transformations: @@ -135,9 +140,8 @@ struct SimplifyPointers::State { // variable identifier. ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { // Look to see if we need to swap this Expression with a saved variable. - auto it = saved_vars.find(expr); - if (it != saved_vars.end()) { - return ctx.dst->Expr(it->second); + if (auto* saved_var = saved_vars.Find(expr)) { + return ctx.dst->Expr(*saved_var); } // Reduce the expression, folding away chains of address-of / indirections @@ -174,7 +178,7 @@ struct SimplifyPointers::State { // Scan the initializer expression for array index expressions that need // to be hoist to temporary "saved" variables. - std::vector saved; + utils::Vector saved; CollectSavedArrayIndices( var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { // We have a sub-expression that needs to be saved. @@ -182,18 +186,18 @@ struct SimplifyPointers::State { auto saved_name = ctx.dst->Symbols().New( ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); - saved.emplace_back(decl); + saved.Push(decl); // Record the substitution of `idx_expr` to the saved variable // with the symbol `saved_name`. This will be used by the // ReplaceAll() handler above. - saved_vars.emplace(idx_expr, saved_name); + saved_vars.Add(idx_expr, saved_name); }); // Find the place to insert the saved declarations. // Special care needs to be made for lets declared as the initializer // part of for-loops. In this case the block will hold the for-loop // statement, not the let. - if (!saved.empty()) { + if (!saved.IsEmpty()) { auto* stmt = ctx.src->Sem().Get(let); auto* block = stmt->Block(); // Find the statement owned by the block (either the let decl or a @@ -219,7 +223,9 @@ struct SimplifyPointers::State { RemoveStatement(ctx, let); } } + ctx.Clone(); + return Program(std::move(b)); } }; @@ -227,8 +233,8 @@ SimplifyPointers::SimplifyPointers() = default; SimplifyPointers::~SimplifyPointers() = default; -void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State(ctx).Run(); +Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const { + return State(src).Run(); } } // namespace tint::transform diff --git a/src/tint/transform/simplify_pointers.h b/src/tint/transform/simplify_pointers.h index 787c7d815f..6e040bb6d6 100644 --- a/src/tint/transform/simplify_pointers.h +++ b/src/tint/transform/simplify_pointers.h @@ -39,16 +39,13 @@ class SimplifyPointers final : public Castable { /// Destructor ~SimplifyPointers() override; - protected: - struct State; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + private: + struct State; }; } // namespace tint::transform diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc index 8d26a7f955..87787ae0cc 100644 --- a/src/tint/transform/single_entry_point.cc +++ b/src/tint/transform/single_entry_point.cc @@ -30,33 +30,37 @@ SingleEntryPoint::SingleEntryPoint() = default; SingleEntryPoint::~SingleEntryPoint() = default; -void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + auto* cfg = inputs.Get(); if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); - - return; + b.Diagnostics().add_error(diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return Program(std::move(b)); } // Find the target entry point. const ast::Function* entry_point = nullptr; - for (auto* f : ctx.src->AST().Functions()) { + for (auto* f : src->AST().Functions()) { if (!f->IsEntryPoint()) { continue; } - if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { + if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) { entry_point = f; break; } } if (entry_point == nullptr) { - ctx.dst->Diagnostics().add_error(diag::System::Transform, - "entry point '" + cfg->entry_point_name + "' not found"); - return; + b.Diagnostics().add_error(diag::System::Transform, + "entry point '" + cfg->entry_point_name + "' not found"); + return Program(std::move(b)); } - auto& sem = ctx.src->Sem(); + auto& sem = src->Sem(); // Build set of referenced module-scope variables for faster lookups later. std::unordered_set referenced_vars; @@ -66,12 +70,12 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c // Clone any module-scope variables, types, and functions that are statically referenced by the // target entry point. - for (auto* decl : ctx.src->AST().GlobalDeclarations()) { + for (auto* decl : src->AST().GlobalDeclarations()) { Switch( decl, // [&](const ast::TypeDecl* ty) { // TODO(jrprice): Strip unused types. - ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); + b.AST().AddTypeDecl(ctx.Clone(ty)); }, [&](const ast::Override* override) { if (referenced_vars.count(override)) { @@ -80,37 +84,39 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c // so that its allocated ID so that it won't be affected by other // stripped away overrides auto* global = sem.Get(override); - const auto* id = ctx.dst->Id(global->OverrideId()); + const auto* id = b.Id(global->OverrideId()); ctx.InsertFront(override->attributes, id); } - ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); + b.AST().AddGlobalVariable(ctx.Clone(override)); } }, [&](const ast::Var* var) { if (referenced_vars.count(var)) { - ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); + b.AST().AddGlobalVariable(ctx.Clone(var)); } }, [&](const ast::Const* c) { // Always keep 'const' declarations, as these can be used by attributes and array // sizes, which are not tracked as transitively used by functions. They also don't // typically get emitted by the backend unless they're actually used. - ctx.dst->AST().AddGlobalVariable(ctx.Clone(c)); + b.AST().AddGlobalVariable(ctx.Clone(c)); }, [&](const ast::Function* func) { if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { - ctx.dst->AST().AddFunction(ctx.Clone(func)); + b.AST().AddFunction(ctx.Clone(func)); } }, - [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); }, + [&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); }, [&](Default) { - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) + TINT_UNREACHABLE(Transform, b.Diagnostics()) << "unhandled global declaration: " << decl->TypeInfo().name; }); } // Clone the entry point. - ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); + b.AST().AddFunction(ctx.Clone(entry_point)); + + return Program(std::move(b)); } SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {} diff --git a/src/tint/transform/single_entry_point.h b/src/tint/transform/single_entry_point.h index 59aa021466..7aba5e8157 100644 --- a/src/tint/transform/single_entry_point.h +++ b/src/tint/transform/single_entry_point.h @@ -53,14 +53,10 @@ class SingleEntryPoint final : public Castable { /// Destructor ~SingleEntryPoint() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc index 127c702b0b..eba66813bf 100644 --- a/src/tint/transform/spirv_atomic.cc +++ b/src/tint/transform/spirv_atomic.cc @@ -37,7 +37,7 @@ namespace tint::transform { using namespace tint::number_suffixes; // NOLINT -/// Private implementation of transform +/// PIMPL state for the transform struct SpirvAtomic::State { private: /// A struct that has been forked because a subset of members were made atomic. @@ -46,19 +46,24 @@ struct SpirvAtomic::State { std::unordered_set atomic_members; }; - CloneContext& ctx; - ProgramBuilder& b = *ctx.dst; + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; + /// The clone context + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; std::unordered_map forked_structs; std::unordered_set atomic_variables; utils::UniqueVector atomic_expressions; public: /// Constructor - /// @param c the clone context - explicit State(CloneContext& c) : ctx(c) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Runs the transform - void Run() { + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { // Look for stub functions generated by the SPIR-V reader, which are used as placeholders // for atomic builtin calls. for (auto* fn : ctx.src->AST().Functions()) { @@ -102,6 +107,10 @@ struct SpirvAtomic::State { } } + if (atomic_expressions.IsEmpty()) { + return SkipTransform; + } + // Transform all variables and structure members that were used in atomic operations as // atomic types. This propagates up originating expression chains. ProcessAtomicExpressions(); @@ -143,6 +152,7 @@ struct SpirvAtomic::State { ReplaceLoadsAndStores(); ctx.Clone(); + return Program(std::move(b)); } private: @@ -297,17 +307,8 @@ const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const { ctx->dst->AllocateNodeID(), builtin); } -bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const { - for (auto* fn : program->AST().Functions()) { - if (ast::HasAttribute(fn->attributes)) { - return true; - } - } - return false; -} - -void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State{ctx}.Run(); +Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const { + return State{src}.Run(); } } // namespace tint::transform diff --git a/src/tint/transform/spirv_atomic.h b/src/tint/transform/spirv_atomic.h index e1311c5952..0f99dba256 100644 --- a/src/tint/transform/spirv_atomic.h +++ b/src/tint/transform/spirv_atomic.h @@ -63,21 +63,13 @@ class SpirvAtomic final : public Castable { const sem::BuiltinType builtin; }; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; - protected: + private: struct State; - - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc index 920fc69f32..0746bda412 100644 --- a/src/tint/transform/std140.cc +++ b/src/tint/transform/std140.cc @@ -77,14 +77,20 @@ struct Hasher { namespace tint::transform { -/// The PIMPL state for the Std140 transform +/// PIMPL state for the transform struct Std140::State { /// Constructor - /// @param c the CloneContext - explicit State(CloneContext& c) : ctx(c) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} /// Runs the transform - void Run() { + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { + if (!ShouldRun()) { + // Transform is not required + return SkipTransform; + } + // Begin by creating forked types for any type that is used as a uniform buffer, that // either directly or transitively contains a matrix that needs splitting for std140 layout. ForkTypes(); @@ -116,11 +122,11 @@ struct Std140::State { }); ctx.Clone(); + return Program(std::move(b)); } /// @returns true if this transform should be run for the given program - /// @param program the program to inspect - static bool ShouldRun(const Program* program) { + bool ShouldRun() const { // Returns true if the type needs to be forked for std140 usage. auto needs_fork = [&](const sem::Type* ty) { while (auto* arr = ty->As()) { @@ -135,7 +141,7 @@ struct Std140::State { }; // Scan structures for members that need forking - for (auto* ty : program->Types()) { + for (auto* ty : src->Types()) { if (auto* str = ty->As()) { if (str->UsedAs(ast::AddressSpace::kUniform)) { for (auto* member : str->Members()) { @@ -148,8 +154,8 @@ struct Std140::State { } // Scan uniform variables that have types that need forking - for (auto* decl : program->AST().GlobalVariables()) { - auto* global = program->Sem().Get(decl); + for (auto* decl : src->AST().GlobalVariables()) { + auto* global = src->Sem().Get(decl); if (global->AddressSpace() == ast::AddressSpace::kUniform) { if (needs_fork(global->Type()->UnwrapRef())) { return true; @@ -197,14 +203,16 @@ struct Std140::State { } }; + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; - /// Alias to the semantic info in ctx.src - const sem::Info& sem = ctx.src->Sem(); - /// Alias to the symbols in ctx.src - const SymbolTable& sym = ctx.src->Symbols(); - /// Alias to the ctx.dst program builder - ProgramBuilder& b = *ctx.dst; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; + /// Alias to the semantic info in src + const sem::Info& sem = src->Sem(); + /// Alias to the symbols in src + const SymbolTable& sym = src->Symbols(); /// Map of load function signature, to the generated function utils::Hashmap load_fns; @@ -218,7 +226,7 @@ struct Std140::State { // Map of original structure to 'std140' forked structure utils::Hashmap std140_structs; - // Map of structure member in ctx.src of a matrix type, to list of decomposed column + // Map of structure member in src of a matrix type, to list of decomposed column // members in ctx.dst. utils::Hashmap, 8> std140_mat_members; @@ -232,7 +240,7 @@ struct Std140::State { utils::Vector columns; }; - // Map of matrix type in ctx.src, to decomposed column structure in ctx.dst. + // Map of matrix type in src, to decomposed column structure in ctx.dst. utils::Hashmap std140_mats; /// AccessChain describes a chain of access expressions to uniform buffer variable. @@ -266,7 +274,7 @@ struct Std140::State { /// map (via Std140Type()). void ForkTypes() { // For each module scope declaration... - for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) { + for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) { // Check to see if this is a structure used by a uniform buffer... auto* str = sem.Get(global); if (str && str->UsedAs(ast::AddressSpace::kUniform)) { @@ -317,7 +325,7 @@ struct Std140::State { if (fork_std140) { // Clone any members that have not already been cloned. for (auto& member : members) { - if (member->program_id == ctx.src->ID()) { + if (member->program_id == src->ID()) { member = ctx.Clone(member); } } @@ -326,7 +334,7 @@ struct Std140::State { auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140"); auto* std140 = b.create(name, std::move(members), ctx.Clone(str->Declaration()->attributes)); - ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140); + ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140); std140_structs.Add(str, name); } } @@ -337,14 +345,13 @@ struct Std140::State { /// type that has been forked for std140-layout. /// Populates the #std140_uniforms set. void ReplaceUniformVarTypes() { - for (auto* global : ctx.src->AST().GlobalVariables()) { + for (auto* global : src->AST().GlobalVariables()) { if (auto* var = global->As()) { if (var->declared_address_space == ast::AddressSpace::kUniform) { auto* v = sem.Get(var); if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) { ctx.Replace(global->type, std140_ty); std140_uniforms.Add(v); - continue; } } } @@ -404,7 +411,7 @@ struct Std140::State { auto std140_mat = std140_mats.GetOrCreate(mat, [&] { auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) + "_" + - ctx.src->FriendlyName(mat->type())); + src->FriendlyName(mat->type())); auto members = DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size()); b.Structure(name, members); @@ -421,7 +428,7 @@ struct Std140::State { if (auto* std140 = Std140Type(arr->ElemType())) { utils::Vector attrs; if (!arr->IsStrideImplicit()) { - attrs.Push(ctx.dst->create(arr->Stride())); + attrs.Push(b.create(arr->Stride())); } auto count = arr->ConstantCount(); if (!count) { @@ -429,7 +436,7 @@ struct Std140::State { // * Override-expression counts can only be applied to workgroup arrays, and // this method only handles types transitively used as uniform buffers. // * Runtime-sized arrays cannot be used in uniform buffers. - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; count = 1; } @@ -440,7 +447,7 @@ struct Std140::State { }); } - /// @param mat the matrix to decompose (in ctx.src) + /// @param mat the matrix to decompose (in src) /// @param name_prefix the name prefix to apply to each of the returned column vector members. /// @param align the alignment in bytes of the matrix. /// @param size the size in bytes of the matrix. @@ -473,7 +480,7 @@ struct Std140::State { // Build the member const auto col_name = name_prefix + std::to_string(i); const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType()); - const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes)); + const auto* col_member = b.Member(col_name, col_ty, std::move(attributes)); // Record the member for std140_mat_members out.Push(col_member); } @@ -618,7 +625,7 @@ struct Std140::State { /// @returns a name suffix for a std140 -> non-std140 conversion function based on the type /// being converted. - const std::string ConvertSuffix(const sem::Type* ty) const { + const std::string ConvertSuffix(const sem::Type* ty) { return Switch( ty, // [&](const sem::Struct* str) { return sym.NameFor(str->Name()); }, @@ -629,8 +636,7 @@ struct Std140::State { // * Override-expression counts can only be applied to workgroup arrays, and // this method only handles types transitively used as uniform buffers. // * Runtime-sized arrays cannot be used in uniform buffers. - TINT_ICE(Transform, ctx.dst->Diagnostics()) - << "unexpected non-constant array count"; + TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; count = 1; } return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType()); @@ -642,7 +648,7 @@ struct Std140::State { [&](const sem::F32*) { return "f32"; }, [&](Default) { TINT_ICE(Transform, b.Diagnostics()) - << "unhandled type for conversion name: " << ctx.src->FriendlyName(ty); + << "unhandled type for conversion name: " << src->FriendlyName(ty); return ""; }); } @@ -718,8 +724,7 @@ struct Std140::State { stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args)))); } else { TINT_ICE(Transform, b.Diagnostics()) - << "failed to find std140 matrix info for: " - << ctx.src->FriendlyName(ty); + << "failed to find std140 matrix info for: " << src->FriendlyName(ty); } }, // [&](const sem::Array* arr) { @@ -736,7 +741,7 @@ struct Std140::State { // * Override-expression counts can only be applied to workgroup arrays, and // this method only handles types transitively used as uniform buffers. // * Runtime-sized arrays cannot be used in uniform buffers. - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; count = 1; } @@ -749,7 +754,7 @@ struct Std140::State { }, [&](Default) { TINT_ICE(Transform, b.Diagnostics()) - << "unhandled type for conversion: " << ctx.src->FriendlyName(ty); + << "unhandled type for conversion: " << src->FriendlyName(ty); }); // Generate the function @@ -1063,7 +1068,7 @@ struct Std140::State { if (std::get_if(&access)) { const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol)); - const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol); + const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol); ty = chain.var->Type()->UnwrapRef(); return {expr, ty, name}; } @@ -1090,7 +1095,7 @@ struct Std140::State { }, // [&](Default) -> ExprTypeName { TINT_ICE(Transform, b.Diagnostics()) - << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); + << "unhandled type for access chain: " << src->FriendlyName(ty); return {}; }); } @@ -1104,14 +1109,14 @@ struct Std140::State { for (auto el : *swizzle) { rhs += xyzw[el]; } - auto swizzle_ty = ctx.src->Types().Find( + auto swizzle_ty = src->Types().Find( vec->type(), static_cast(swizzle->Length())); auto* expr = b.MemberAccessor(lhs, rhs); return {expr, swizzle_ty, rhs}; }, // [&](Default) -> ExprTypeName { TINT_ICE(Transform, b.Diagnostics()) - << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); + << "unhandled type for access chain: " << src->FriendlyName(ty); return {}; }); } @@ -1140,7 +1145,7 @@ struct Std140::State { }, // [&](Default) -> ExprTypeName { TINT_ICE(Transform, b.Diagnostics()) - << "unhandled type for access chain: " << ctx.src->FriendlyName(ty); + << "unhandled type for access chain: " << src->FriendlyName(ty); return {}; }); } @@ -1150,12 +1155,8 @@ Std140::Std140() = default; Std140::~Std140() = default; -bool Std140::ShouldRun(const Program* program, const DataMap&) const { - return State::ShouldRun(program); -} - -void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State(ctx).Run(); +Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const { + return State(src).Run(); } } // namespace tint::transform diff --git a/src/tint/transform/std140.h b/src/tint/transform/std140.h index ec5cad5cf3..49e663daf0 100644 --- a/src/tint/transform/std140.h +++ b/src/tint/transform/std140.h @@ -34,21 +34,13 @@ class Std140 final : public Castable { /// Destructor ~Std140() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: struct State; - - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/substitute_override.cc b/src/tint/transform/substitute_override.cc index 2de04e097e..7c2d0a2770 100644 --- a/src/tint/transform/substitute_override.cc +++ b/src/tint/transform/substitute_override.cc @@ -15,6 +15,7 @@ #include "src/tint/transform/substitute_override.h" #include +#include #include "src/tint/program_builder.h" #include "src/tint/sem/builtin.h" @@ -25,12 +26,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride); TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config); namespace tint::transform { +namespace { -SubstituteOverride::SubstituteOverride() = default; - -SubstituteOverride::~SubstituteOverride() = default; - -bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* node : program->AST().GlobalVariables()) { if (node->Is()) { return true; @@ -39,18 +37,32 @@ bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const return false; } -void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const { +} // namespace + +SubstituteOverride::SubstituteOverride() = default; + +SubstituteOverride::~SubstituteOverride() = default; + +Transform::ApplyResult SubstituteOverride::Apply(const Program* src, + const DataMap& config, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + const auto* data = config.Get(); if (!data) { - ctx.dst->Diagnostics().add_error(diag::System::Transform, - "Missing override substitution data"); - return; + b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data"); + return Program(std::move(b)); + } + + if (!ShouldRun(ctx.src)) { + return SkipTransform; } ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* { auto* sem = ctx.src->Sem().Get(w); - auto src = ctx.Clone(w->source); + auto source = ctx.Clone(w->source); auto sym = ctx.Clone(w->symbol); auto* ty = ctx.Clone(w->type); @@ -58,30 +70,30 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) auto iter = data->map.find(sem->OverrideId()); if (iter == data->map.end()) { if (!w->initializer) { - ctx.dst->Diagnostics().add_error( + b.Diagnostics().add_error( diag::System::Transform, "Initializer not provided for override, and override not overridden."); return nullptr; } - return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer)); + return b.Const(source, sym, ty, ctx.Clone(w->initializer)); } auto value = iter->second; auto* ctor = Switch( sem->Type(), - [&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to()(value, 0.0)); }, - [&](const sem::I32*) { return ctx.dst->Expr(i32(value)); }, - [&](const sem::U32*) { return ctx.dst->Expr(u32(value)); }, - [&](const sem::F32*) { return ctx.dst->Expr(f32(value)); }, - [&](const sem::F16*) { return ctx.dst->Expr(f16(value)); }); + [&](const sem::Bool*) { return b.Expr(!std::equal_to()(value, 0.0)); }, + [&](const sem::I32*) { return b.Expr(i32(value)); }, + [&](const sem::U32*) { return b.Expr(u32(value)); }, + [&](const sem::F32*) { return b.Expr(f32(value)); }, + [&](const sem::F16*) { return b.Expr(f16(value)); }); if (!ctor) { - ctx.dst->Diagnostics().add_error(diag::System::Transform, - "Failed to create override-expression"); + b.Diagnostics().add_error(diag::System::Transform, + "Failed to create override-expression"); return nullptr; } - return ctx.dst->Const(src, sym, ty, ctor); + return b.Const(source, sym, ty, ctor); }); // Ensure that objects that are indexed with an override-expression are materialized. @@ -89,11 +101,10 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) // resulting type of the index may change. See: crbug.com/tint/1697. ctx.ReplaceAll( [&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { - if (auto* sem = ctx.src->Sem().Get(expr)) { + if (auto* sem = src->Sem().Get(expr)) { if (auto* access = sem->UnwrapMaterialize()->As()) { if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() && access->Index()->Stage() == sem::EvaluationStage::kOverride) { - auto& b = *ctx.dst; auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize), ctx.Clone(expr->object)); return b.IndexAccessor(obj, ctx.Clone(expr->index)); @@ -104,6 +115,7 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) }); ctx.Clone(); + return Program(std::move(b)); } SubstituteOverride::Config::Config() = default; diff --git a/src/tint/transform/substitute_override.h b/src/tint/transform/substitute_override.h index 940e11d2fa..853acc743b 100644 --- a/src/tint/transform/substitute_override.h +++ b/src/tint/transform/substitute_override.h @@ -75,19 +75,10 @@ class SubstituteOverride final : public Castable /// Destructor ~SubstituteOverride() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/test_helper.h b/src/tint/transform/test_helper.h index bc82fe5a4f..ac48c5fc61 100644 --- a/src/tint/transform/test_helper.h +++ b/src/tint/transform/test_helper.h @@ -122,7 +122,18 @@ class TransformTestBase : public BASE { } const Transform& t = TRANSFORM(); - return t.ShouldRun(&program, data); + + DataMap outputs; + auto result = t.Apply(&program, data, outputs); + if (!result) { + return false; + } + if (!result->IsValid()) { + ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: " + << result->Diagnostics().str(); + return true; + } + return result.has_value(); } /// @param in the input WGSL source diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc index 3e034112b5..c37f3b49a7 100644 --- a/src/tint/transform/transform.cc +++ b/src/tint/transform/transform.cc @@ -46,24 +46,19 @@ Output::Output(Program&& p) : program(std::move(p)) {} Transform::Transform() = default; Transform::~Transform() = default; -Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const { - ProgramBuilder builder; - CloneContext ctx(&builder, program); +Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const { Output output; - Run(ctx, data, output.data); - output.program = Program(std::move(builder)); + if (auto program = Apply(src, data, output.data)) { + output.program = std::move(program.value()); + } else { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + ctx.Clone(); + output.program = Program(std::move(b)); + } return output; } -void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics()) - << "Transform::Run() unimplemented for " << TypeInfo().name; -} - -bool Transform::ShouldRun(const Program*, const DataMap&) const { - return true; -} - void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) { auto* sem = ctx.src->Sem().Get(stmt); if (auto* block = tint::As(sem->Parent())) { diff --git a/src/tint/transform/transform.h b/src/tint/transform/transform.h index c3e3d1d811..6580e25be0 100644 --- a/src/tint/transform/transform.h +++ b/src/tint/transform/transform.h @@ -158,26 +158,30 @@ class Transform : public Castable { /// Destructor ~Transform() override; - /// Runs the transform on `program`, returning the transformation result. + /// Runs the transform on @p program, returning the transformation result or a clone of + /// @p program. /// @param program the source program to transform /// @param data optional extra transform-specific input data /// @returns the transformation result - virtual Output Run(const Program* program, const DataMap& data = {}) const; + Output Run(const Program* program, const DataMap& data = {}) const; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const; + /// The return value of Apply(). + /// If SkipTransform (std::nullopt), then the transform is not needed to be run. + using ApplyResult = std::optional; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder + /// Value returned from Apply() to indicate that the transform does not need to be run + static inline constexpr std::nullopt_t SkipTransform = std::nullopt; + + /// Runs the transform on `program`, return. + /// @param program the input program /// @param inputs optional extra transform-specific input data /// @param outputs optional extra transform-specific output data - virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const; + /// @returns a transformed program, or std::nullopt if the transform didn't need to run. + virtual ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const = 0; + protected: /// Removes the statement `stmt` from the transformed program. /// RemoveStatement handles edge cases, like statements in the initializer and /// continuing of for-loops. diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc index d063ba2377..82fdf6a8e6 100644 --- a/src/tint/transform/transform_test.cc +++ b/src/tint/transform/transform_test.cc @@ -23,7 +23,9 @@ namespace { // Inherit from Transform so we have access to protected methods struct CreateASTTypeForTest : public testing::Test, public Transform { - Output Run(const Program*, const DataMap&) const override { return {}; } + ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override { + return SkipTransform; + } const ast::Type* create(std::function create_sem_type) { ProgramBuilder sem_type_builder; diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc index 975e2ed236..93ce595a13 100644 --- a/src/tint/transform/unshadow.cc +++ b/src/tint/transform/unshadow.cc @@ -28,27 +28,32 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow); namespace tint::transform { -/// The PIMPL state for the Unshadow transform +/// PIMPL state for the transform struct Unshadow::State { + /// The source program + const Program* const src; + /// The target program builder + ProgramBuilder b; /// The clone context - CloneContext& ctx; + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; /// Constructor - /// @param context the clone context - explicit State(CloneContext& context) : ctx(context) {} + /// @param program the source program + explicit State(const Program* program) : src(program) {} - /// Performs the transformation - void Run() { - auto& sem = ctx.src->Sem(); + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + Transform::ApplyResult Run() { + auto& sem = src->Sem(); // Maps a variable to its new name. - std::unordered_map renamed_to; + utils::Hashmap renamed_to; auto rename = [&](const sem::Variable* v) -> const ast::Variable* { auto* decl = v->Declaration(); - auto name = ctx.src->Symbols().NameFor(decl->symbol); - auto symbol = ctx.dst->Symbols().New(name); - renamed_to.emplace(v, symbol); + auto name = src->Symbols().NameFor(decl->symbol); + auto symbol = b.Symbols().New(name); + renamed_to.Add(v, symbol); auto source = ctx.Clone(decl->source); auto* type = ctx.Clone(decl->type); @@ -57,20 +62,20 @@ struct Unshadow::State { return Switch( decl, // [&](const ast::Var* var) { - return ctx.dst->Var(source, symbol, type, var->declared_address_space, - var->declared_access, initializer, attributes); + return b.Var(source, symbol, type, var->declared_address_space, + var->declared_access, initializer, attributes); }, [&](const ast::Let*) { - return ctx.dst->Let(source, symbol, type, initializer, attributes); + return b.Let(source, symbol, type, initializer, attributes); }, [&](const ast::Const*) { - return ctx.dst->Const(source, symbol, type, initializer, attributes); + return b.Const(source, symbol, type, initializer, attributes); }, - [&](const ast::Parameter*) { - return ctx.dst->Param(source, symbol, type, attributes); + [&](const ast::Parameter*) { // + return b.Param(source, symbol, type, attributes); }, [&](Default) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "unexpected variable type: " << decl->TypeInfo().name; return nullptr; }); @@ -92,14 +97,15 @@ struct Unshadow::State { ctx.ReplaceAll( [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { if (auto* user = sem.Get(ident)) { - auto it = renamed_to.find(user->Variable()); - if (it != renamed_to.end()) { - return ctx.dst->Expr(it->second); + if (auto* renamed = renamed_to.Find(user->Variable())) { + return b.Expr(*renamed); } } return nullptr; }); + ctx.Clone(); + return Program(std::move(b)); } }; @@ -107,8 +113,8 @@ Unshadow::Unshadow() = default; Unshadow::~Unshadow() = default; -void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - State(ctx).Run(); +Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const { + return State(src).Run(); } } // namespace tint::transform diff --git a/src/tint/transform/unshadow.h b/src/tint/transform/unshadow.h index 5ffe8399b9..8ebf105cd6 100644 --- a/src/tint/transform/unshadow.h +++ b/src/tint/transform/unshadow.h @@ -29,16 +29,13 @@ class Unshadow final : public Castable { /// Destructor ~Unshadow() override; - protected: - struct State; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + private: + struct State; }; } // namespace tint::transform diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc index 4e20d55325..068fe35989 100644 --- a/src/tint/transform/unwind_discard_functions.cc +++ b/src/tint/transform/unwind_discard_functions.cc @@ -35,7 +35,51 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions); namespace tint::transform { namespace { -class State { +bool ShouldRun(const Program* program) { + auto& sem = program->Sem(); + for (auto* f : program->AST().Functions()) { + if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { + return true; + } + } + return false; +} + +} // namespace + +/// PIMPL state for the transform +struct UnwindDiscardFunctions::State { + /// Constructor + /// @param ctx_in the context + explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} + + /// Runs the transform + void Run() { + ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { + // Iterate block statements and replace them as needed. + for (auto* stmt : block->statements) { + if (auto* new_stmt = Statement(stmt)) { + ctx.Replace(stmt, new_stmt); + } + + // Handle for loops, as they are the only other AST node that + // contains statements outside of BlockStatements. + if (auto* fl = stmt->As()) { + if (auto* new_stmt = Statement(fl->initializer)) { + ctx.Replace(fl->initializer, new_stmt); + } + if (auto* new_stmt = Statement(fl->continuing)) { + // NOTE: Should never reach here as we cannot discard in a + // continuing block. + ctx.Replace(fl->continuing, new_stmt); + } + } + } + + return nullptr; + }); + } + private: CloneContext& ctx; ProgramBuilder& b; @@ -163,7 +207,7 @@ class State { // Returns true if `stmt` is a for-loop initializer statement. bool IsForLoopInitStatement(const ast::Statement* stmt) { if (auto* sem_stmt = sem.Get(stmt)) { - if (auto* sem_fl = As(sem_stmt->Parent())) { + if (auto* sem_fl = tint::As(sem_stmt->Parent())) { return sem_fl->Declaration()->initializer == stmt; } } @@ -305,60 +349,26 @@ class State { return TryInsertAfter(s, sem_expr); }); } - - public: - /// Constructor - /// @param ctx_in the context - explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {} - - /// Runs the transform - void Run() { - ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { - // Iterate block statements and replace them as needed. - for (auto* stmt : block->statements) { - if (auto* new_stmt = Statement(stmt)) { - ctx.Replace(stmt, new_stmt); - } - - // Handle for loops, as they are the only other AST node that - // contains statements outside of BlockStatements. - if (auto* fl = stmt->As()) { - if (auto* new_stmt = Statement(fl->initializer)) { - ctx.Replace(fl->initializer, new_stmt); - } - if (auto* new_stmt = Statement(fl->continuing)) { - // NOTE: Should never reach here as we cannot discard in a - // continuing block. - ctx.Replace(fl->continuing, new_stmt); - } - } - } - - return nullptr; - }); - - ctx.Clone(); - } }; -} // namespace - UnwindDiscardFunctions::UnwindDiscardFunctions() = default; UnwindDiscardFunctions::~UnwindDiscardFunctions() = default; -void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + State state(ctx); state.Run(); -} -bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const { - auto& sem = program->Sem(); - for (auto* f : program->AST().Functions()) { - if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) { - return true; - } - } - return false; + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/unwind_discard_functions.h b/src/tint/transform/unwind_discard_functions.h index 105a9d8d39..7614c2759b 100644 --- a/src/tint/transform/unwind_discard_functions.h +++ b/src/tint/transform/unwind_discard_functions.h @@ -44,19 +44,13 @@ class UnwindDiscardFunctions final : public CastableStmt(), std::move(builder))) { + return false; + } + } else { + auto builder = [this, expr, name] { + return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); + }; + if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { + return false; + } + } + + // Replace the initializer expression with a reference to the let + ctx.Replace(expr, b.Expr(name)); + return true; + } + + /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*) + bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { + if (stmt) { + auto builder = [stmt] { return stmt; }; + return InsertBeforeImpl(before_stmt, std::move(builder)); + } + return InsertBeforeImpl(before_stmt, Decompose{}); + } + + /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&) + bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { + return InsertBeforeImpl(before_stmt, std::move(builder)); + } + + /// @copydoc HoistToDeclBefore::Prepare() + bool Prepare(const sem::Expression* before_expr) { + return InsertBefore(before_expr->Stmt(), nullptr); + } + + private: CloneContext& ctx; ProgramBuilder& b; @@ -215,6 +267,8 @@ class HoistToDeclBefore::State { template bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) { + (void)builder; // Avoid 'unused parameter' warning due to 'if constexpr' + auto* ip = before_stmt->Declaration(); auto* else_if = before_stmt->As(); @@ -299,58 +353,6 @@ class HoistToDeclBefore::State { << "unhandled expression parent statement type: " << parent->TypeInfo().name; return false; } - - public: - /// Constructor - /// @param ctx_in the clone context - explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {} - - /// @copydoc HoistToDeclBefore::Add() - bool Add(const sem::Expression* before_expr, - const ast::Expression* expr, - bool as_let, - const char* decl_name) { - auto name = b.Symbols().New(decl_name); - - if (as_let) { - auto builder = [this, expr, name] { - return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr))); - }; - if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { - return false; - } - } else { - auto builder = [this, expr, name] { - return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr))); - }; - if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) { - return false; - } - } - - // Replace the initializer expression with a reference to the let - ctx.Replace(expr, b.Expr(name)); - return true; - } - - /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*) - bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { - if (stmt) { - auto builder = [stmt] { return stmt; }; - return InsertBeforeImpl(before_stmt, std::move(builder)); - } - return InsertBeforeImpl(before_stmt, Decompose{}); - } - - /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&) - bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) { - return InsertBeforeImpl(before_stmt, std::move(builder)); - } - - /// @copydoc HoistToDeclBefore::Prepare() - bool Prepare(const sem::Expression* before_expr) { - return InsertBefore(before_expr->Stmt(), nullptr); - } }; HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique(ctx)) {} diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h index b2e993e398..d9a8a8a632 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.h +++ b/src/tint/transform/utils/hoist_to_decl_before.h @@ -77,7 +77,7 @@ class HoistToDeclBefore { bool Prepare(const sem::Expression* before_expr); private: - class State; + struct State; std::unique_ptr state_; }; diff --git a/src/tint/transform/var_for_dynamic_index.cc b/src/tint/transform/var_for_dynamic_index.cc index d30831e14b..81af013941 100644 --- a/src/tint/transform/var_for_dynamic_index.cc +++ b/src/tint/transform/var_for_dynamic_index.cc @@ -13,6 +13,9 @@ // limitations under the License. #include "src/tint/transform/var_for_dynamic_index.h" + +#include + #include "src/tint/program_builder.h" #include "src/tint/transform/utils/hoist_to_decl_before.h" @@ -22,7 +25,12 @@ VarForDynamicIndex::VarForDynamicIndex() = default; VarForDynamicIndex::~VarForDynamicIndex() = default; -void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src, + const DataMap&, + DataMap&) const { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + HoistToDeclBefore hoist_to_decl_before(ctx); // Extracts array and matrix values that are dynamically indexed to a @@ -30,7 +38,7 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) { auto* index_expr = access_expr->index; auto* object_expr = access_expr->object; - auto& sem = ctx.src->Sem(); + auto& sem = src->Sem(); if (sem.Get(index_expr)->ConstantValue()) { // Index expression resolves to a compile time value. @@ -49,15 +57,21 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index"); }; - for (auto* node : ctx.src->ASTNodes().Objects()) { + bool index_accessor_found = false; + for (auto* node : src->ASTNodes().Objects()) { if (auto* access_expr = node->As()) { if (!dynamic_index_to_var(access_expr)) { - return; + return Program(std::move(b)); } + index_accessor_found = true; } } + if (!index_accessor_found) { + return SkipTransform; + } ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/var_for_dynamic_index.h b/src/tint/transform/var_for_dynamic_index.h index 39ef2f2544..070a2cd1cf 100644 --- a/src/tint/transform/var_for_dynamic_index.h +++ b/src/tint/transform/var_for_dynamic_index.h @@ -31,14 +31,10 @@ class VarForDynamicIndex : public Transform { /// Destructor ~VarForDynamicIndex() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/vectorize_matrix_conversions.cc b/src/tint/transform/vectorize_matrix_conversions.cc index 576b8857df..94fbdf303b 100644 --- a/src/tint/transform/vectorize_matrix_conversions.cc +++ b/src/tint/transform/vectorize_matrix_conversions.cc @@ -30,11 +30,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions); namespace tint::transform { -VectorizeMatrixConversions::VectorizeMatrixConversions() = default; +namespace { -VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; - -bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* node : program->ASTNodes().Objects()) { if (auto* sem = program->Sem().Get(node)) { if (auto* call = sem->UnwrapMaterialize()->As()) { @@ -50,14 +48,29 @@ bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap return false; } -void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +} // namespace + +VectorizeMatrixConversions::VectorizeMatrixConversions() = default; + +VectorizeMatrixConversions::~VectorizeMatrixConversions() = default; + +Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + using HelperFunctionKey = utils::UnorderedKeyWrapper>; std::unordered_map matrix_convs; ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { - auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As(); + auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As(); auto* ty_conv = call->Target()->As(); if (!ty_conv) { return nullptr; @@ -72,16 +85,16 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& return nullptr; } - auto& src = args[0]; + auto& matrix = args[0]; - auto* src_type = args[0]->Type()->UnwrapRef()->As(); + auto* src_type = matrix->Type()->UnwrapRef()->As(); if (!src_type) { return nullptr; } // The source and destination type of a matrix conversion must have a same shape. if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "source and destination matrix has different shape in matrix conversion"; return nullptr; } @@ -90,47 +103,45 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap& utils::Vector columns; for (uint32_t c = 0; c < dst_type->columns(); c++) { auto* src_matrix_expr = src_expression_builder(); - auto* src_column_expr = - ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c))); - columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), - src_column_expr)); + auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c))); + columns.Push( + b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr)); } - return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns); + return b.Construct(CreateASTTypeFor(ctx, dst_type), columns); }; // Replace the matrix conversion to column vector conversions and a matrix construction. - if (!src->HasSideEffects()) { + if (!matrix->HasSideEffects()) { // Simply use the argument's declaration if it has no side effects. return build_vectorized_conversion_expression([&]() { // - return ctx.Clone(src->Declaration()); + return ctx.Clone(matrix->Declaration()); }); } else { // If has side effects, use a helper function. auto fn = utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] { - auto name = - ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) + - "x" + std::to_string(src_type->rows()) + "_" + - ctx.dst->FriendlyName(src_type->type()) + "_" + - ctx.dst->FriendlyName(dst_type->type())); - ctx.dst->Func( - name, - utils::Vector{ - ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)), - }, - CreateASTTypeFor(ctx, dst_type), - utils::Vector{ - ctx.dst->Return(build_vectorized_conversion_expression([&]() { // - return ctx.dst->Expr("value"); - })), - }); + auto name = b.Symbols().New( + "convert_mat" + std::to_string(src_type->columns()) + "x" + + std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) + + "_" + b.FriendlyName(dst_type->type())); + b.Func(name, + utils::Vector{ + b.Param("value", CreateASTTypeFor(ctx, src_type)), + }, + CreateASTTypeFor(ctx, dst_type), + utils::Vector{ + b.Return(build_vectorized_conversion_expression([&]() { // + return b.Expr("value"); + })), + }); return name; }); - return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); + return b.Call(fn, ctx.Clone(args[0]->Declaration())); } }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/vectorize_matrix_conversions.h b/src/tint/transform/vectorize_matrix_conversions.h index f16467c3dc..c86240c1d9 100644 --- a/src/tint/transform/vectorize_matrix_conversions.h +++ b/src/tint/transform/vectorize_matrix_conversions.h @@ -28,19 +28,10 @@ class VectorizeMatrixConversions final : public CastableASTNodes().Objects()) { if (auto* call = program->Sem().Get(node)) { if (call->Target()->Is() && call->Type()->Is()) { @@ -46,11 +43,26 @@ bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const return false; } -void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +} // namespace + +VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default; + +VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default; + +Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + std::unordered_map scalar_inits; ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { - auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As(); + auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As(); auto* ty_init = call->Target()->As(); if (!ty_init) { return nullptr; @@ -87,10 +99,10 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D } // Construct the column vector. - columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), - std::move(row_values))); + columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), + std::move(row_values))); } - return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns); + return b.Construct(CreateASTTypeFor(ctx, mat_type), columns); }; if (args.Length() == 1) { @@ -98,23 +110,22 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D // This is done to ensure that the single argument value is only evaluated once, and // with the correct expression evaluation order. auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] { - auto name = - ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" + - std::to_string(mat_type->rows())); - ctx.dst->Func(name, - utils::Vector{ - // Single scalar parameter - ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())), - }, - CreateASTTypeFor(ctx, mat_type), - utils::Vector{ - ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { // - return ctx.dst->Expr("value"); - })), - }); + auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) + + "x" + std::to_string(mat_type->rows())); + b.Func(name, + utils::Vector{ + // Single scalar parameter + b.Param("value", CreateASTTypeFor(ctx, mat_type->type())), + }, + CreateASTTypeFor(ctx, mat_type), + utils::Vector{ + b.Return(build_mat([&](uint32_t, uint32_t) { // + return b.Expr("value"); + })), + }); return name; }); - return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration())); + return b.Call(fn, ctx.Clone(args[0]->Declaration())); } if (args.Length() == mat_type->columns() * mat_type->rows()) { @@ -123,12 +134,13 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D }); } - TINT_ICE(Transform, ctx.dst->Diagnostics()) + TINT_ICE(Transform, b.Diagnostics()) << "matrix initializer has unexpected number of arguments"; return nullptr; }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/vectorize_scalar_matrix_initializers.h b/src/tint/transform/vectorize_scalar_matrix_initializers.h index 342754ac06..f9c0164e2f 100644 --- a/src/tint/transform/vectorize_scalar_matrix_initializers.h +++ b/src/tint/transform/vectorize_scalar_matrix_initializers.h @@ -29,19 +29,10 @@ class VectorizeScalarMatrixInitializers final /// Destructor ~VectorizeScalarMatrixInitializers() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc index 00b5a065ea..d5ee424262 100644 --- a/src/tint/transform/vertex_pulling.cc +++ b/src/tint/transform/vertex_pulling.cc @@ -201,13 +201,46 @@ DataType DataTypeOf(VertexFormat format) { return {BaseType::kInvalid, 0}; } -struct State { - State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {} - State(const State&) = default; - ~State() = default; +} // namespace - /// LocationReplacement describes an ast::Variable replacement for a - /// location input. +/// PIMPL state for the transform +struct VertexPulling::State { + /// Constructor + /// @param program the source program + /// @param c the VertexPulling config + State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {} + + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + ApplyResult Run() { + // Find entry point + const ast::Function* func = nullptr; + for (auto* fn : src->AST().Functions()) { + if (fn->PipelineStage() == ast::PipelineStage::kVertex) { + if (func != nullptr) { + b.Diagnostics().add_error( + diag::System::Transform, + "VertexPulling found more than one vertex entry point"); + return Program(std::move(b)); + } + func = fn; + } + } + if (func == nullptr) { + b.Diagnostics().add_error(diag::System::Transform, + "Vertex stage entry point not found"); + return Program(std::move(b)); + } + + AddVertexStorageBuffers(); + Process(func); + + ctx.Clone(); + return Program(std::move(b)); + } + + private: + /// LocationReplacement describes an ast::Variable replacement for a location input. struct LocationReplacement { /// The variable to replace in the source Program ast::Variable* from; @@ -215,13 +248,22 @@ struct State { ast::Variable* to; }; + /// LocationInfo describes an input location struct LocationInfo { + /// A builder that builds the expression that resolves to the (transformed) input location std::function expr; + /// The store type of the location variable const sem::Type* type; }; - CloneContext& ctx; + /// The source program + const Program* const src; + /// The transform config VertexPulling::Config const cfg; + /// The target program builder + ProgramBuilder b; + /// The clone context + CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; std::unordered_map location_info; std::function vertex_index_expr = nullptr; std::function instance_index_expr = nullptr; @@ -235,7 +277,7 @@ struct State { Symbol GetVertexBufferName(uint32_t index) { return utils::GetOrCreate(vertex_buffer_names, index, [&] { static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_"; - return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); + return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index)); }); } @@ -243,7 +285,7 @@ struct State { Symbol GetStructBufferName() { if (!struct_buffer_name.IsValid()) { static const char kStructBufferName[] = "tint_vertex_data"; - struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName); + struct_buffer_name = b.Symbols().New(kStructBufferName); } return struct_buffer_name; } @@ -252,21 +294,19 @@ struct State { void AddVertexStorageBuffers() { // Creating the struct type static const char kStructName[] = "TintVertexData"; - auto* struct_type = - ctx.dst->Structure(ctx.dst->Symbols().New(kStructName), - utils::Vector{ - ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array()), - }); + auto* struct_type = b.Structure(b.Symbols().New(kStructName), + utils::Vector{ + b.Member(GetStructBufferName(), b.ty.array()), + }); for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { // The decorated variable with struct type - ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type), - ast::AddressSpace::kStorage, ast::Access::kRead, - ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group))); + b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage, + ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group))); } } /// Creates and returns the assignment to the variables from the buffers - ast::BlockStatement* CreateVertexPullingPreamble() { + const ast::BlockStatement* CreateVertexPullingPreamble() { // Assign by looking at the vertex descriptor to find attributes with // matching location. @@ -276,7 +316,7 @@ struct State { const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; if ((buffer_layout.array_stride & 3) != 0) { - ctx.dst->Diagnostics().add_error( + b.Diagnostics().add_error( diag::System::Transform, "WebGPU requires that vertex stride must be a multiple of 4 bytes, " "but VertexPulling array stride for buffer " + @@ -292,15 +332,15 @@ struct State { // buffer_array_base is the base array offset for all the vertex // attributes. These are units of uint (4 bytes). auto buffer_array_base = - ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); + b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx)); auto* attribute_offset = index_expr; if (buffer_layout.array_stride != 4) { - attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u)); + attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u)); } // let pulling_offset_n = - stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset))); + stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset))); for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { auto it = location_info.find(attribute_desc.shader_location); @@ -320,8 +360,8 @@ struct State { err << "VertexAttributeDescriptor for location " << std::to_string(attribute_desc.shader_location) << " has format " << attribute_desc.format << " but shader expects " - << var.type->FriendlyName(ctx.src->Symbols()); - ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str()); + << var.type->FriendlyName(src->Symbols()); + b.Diagnostics().add_error(diag::System::Transform, err.str()); return nullptr; } @@ -337,16 +377,16 @@ struct State { // WGSL variable vector width is smaller than the loaded vector width switch (var_dt.width) { case 1: - value = ctx.dst->MemberAccessor(fetch, "x"); + value = b.MemberAccessor(fetch, "x"); break; case 2: - value = ctx.dst->MemberAccessor(fetch, "xy"); + value = b.MemberAccessor(fetch, "xy"); break; case 3: - value = ctx.dst->MemberAccessor(fetch, "xyz"); + value = b.MemberAccessor(fetch, "xyz"); break; default: - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width; + TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width; return nullptr; } } else if (var_dt.width > fmt_dt.width) { @@ -355,32 +395,32 @@ struct State { utils::Vector values{fetch}; switch (var_dt.base_type) { case BaseType::kI32: - ty = ctx.dst->ty.i32(); + ty = b.ty.i32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { - values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i)); + values.Push(b.Expr((i == 3) ? 1_i : 0_i)); } break; case BaseType::kU32: - ty = ctx.dst->ty.u32(); + ty = b.ty.u32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { - values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u)); + values.Push(b.Expr((i == 3) ? 1_u : 0_u)); } break; case BaseType::kF32: - ty = ctx.dst->ty.f32(); + ty = b.ty.f32(); for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) { - values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f)); + values.Push(b.Expr((i == 3) ? 1_f : 0_f)); } break; default: - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type; + TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type; return nullptr; } - value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values); + value = b.Construct(b.ty.vec(ty, var_dt.width), values); } // Assign the value to the WGSL variable - stmts.Push(ctx.dst->Assign(var.expr(), value)); + stmts.Push(b.Assign(var.expr(), value)); } } @@ -388,7 +428,7 @@ struct State { return nullptr; } - return ctx.dst->create(std::move(stmts)); + return b.Block(std::move(stmts)); } /// Generates an expression reading from a buffer a specific format. @@ -407,7 +447,7 @@ struct State { }; // Returns a i32 loaded from buffer_base + offset. - auto load_i32 = [&] { return ctx.dst->Bitcast(load_u32()); }; + auto load_i32 = [&] { return b.Bitcast(load_u32()); }; // Returns a u32 loaded from buffer_base + offset + 4. auto load_next_u32 = [&] { @@ -415,7 +455,7 @@ struct State { }; // Returns a i32 loaded from buffer_base + offset + 4. - auto load_next_i32 = [&] { return ctx.dst->Bitcast(load_next_u32()); }; + auto load_next_i32 = [&] { return b.Bitcast(load_next_u32()); }; // Returns a u16 loaded from offset, packed in the high 16 bits of a u32. // The low 16 bits are 0. @@ -427,17 +467,17 @@ struct State { LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); switch (offset & 3) { case 0: - return ctx.dst->Shl(low_u32, 16_u); + return b.Shl(low_u32, 16_u); case 1: - return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u); + return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u); case 2: - return ctx.dst->And(low_u32, 0xffff0000_u); + return b.And(low_u32, 0xffff0000_u); default: { // 3: auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, VertexFormat::kUint32); - auto* shr = ctx.dst->Shr(low_u32, 8_u); - auto* shl = ctx.dst->Shl(high_u32, 24_u); - return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u); + auto* shr = b.Shr(low_u32, 8_u); + auto* shl = b.Shl(high_u32, 24_u); + return b.And(b.Or(shl, shr), 0xffff0000_u); } } }; @@ -450,24 +490,24 @@ struct State { LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32); switch (offset & 3) { case 0: - return ctx.dst->And(low_u32, 0xffff_u); + return b.And(low_u32, 0xffff_u); case 1: - return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u); + return b.And(b.Shr(low_u32, 8_u), 0xffff_u); case 2: - return ctx.dst->Shr(low_u32, 16_u); + return b.Shr(low_u32, 16_u); default: { // 3: auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer, VertexFormat::kUint32); - auto* shr = ctx.dst->Shr(low_u32, 24_u); - auto* shl = ctx.dst->Shl(high_u32, 8_u); - return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u); + auto* shr = b.Shr(low_u32, 24_u); + auto* shl = b.Shl(high_u32, 8_u); + return b.And(b.Or(shl, shr), 0xffff_u); } } }; // Returns a i16 loaded from offset, packed in the high 16 bits of a u32. // The low 16 bits are 0. - auto load_i16_h = [&] { return ctx.dst->Bitcast(load_u16_h()); }; + auto load_i16_h = [&] { return b.Bitcast(load_u16_h()); }; // Assumptions are made that alignment must be at least as large as the size // of a single component. @@ -480,128 +520,121 @@ struct State { // Vectors of basic primitives case VertexFormat::kUint32x2: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), - VertexFormat::kUint32, 2); + return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2); case VertexFormat::kUint32x3: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), - VertexFormat::kUint32, 3); + return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3); case VertexFormat::kUint32x4: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(), - VertexFormat::kUint32, 4); + return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4); case VertexFormat::kSint32x2: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), - VertexFormat::kSint32, 2); + return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2); case VertexFormat::kSint32x3: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), - VertexFormat::kSint32, 3); + return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3); case VertexFormat::kSint32x4: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(), - VertexFormat::kSint32, 4); + return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4); case VertexFormat::kFloat32x2: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), - VertexFormat::kFloat32, 2); + return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, + 2); case VertexFormat::kFloat32x3: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), - VertexFormat::kFloat32, 3); + return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, + 3); case VertexFormat::kFloat32x4: - return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(), - VertexFormat::kFloat32, 4); + return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32, + 4); case VertexFormat::kUint8x2: { // yyxx0000, yyxx0000 - auto* u16s = ctx.dst->vec2(load_u16_h()); + auto* u16s = b.vec2(load_u16_h()); // xx000000, yyxx0000 - auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2(8_u, 0_u)); + auto* shl = b.Shl(u16s, b.vec2(8_u, 0_u)); // 000000xx, 000000yy - return ctx.dst->Shr(shl, ctx.dst->vec2(24_u)); + return b.Shr(shl, b.vec2(24_u)); } case VertexFormat::kUint8x4: { // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx - auto* u32s = ctx.dst->vec4(load_u32()); + auto* u32s = b.vec4(load_u32()); // xx000000, yyxx0000, zzyyxx00, wwzzyyxx - auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4(24_u, 16_u, 8_u, 0_u)); + auto* shl = b.Shl(u32s, b.vec4(24_u, 16_u, 8_u, 0_u)); // 000000xx, 000000yy, 000000zz, 000000ww - return ctx.dst->Shr(shl, ctx.dst->vec4(24_u)); + return b.Shr(shl, b.vec4(24_u)); } case VertexFormat::kUint16x2: { // yyyyxxxx, yyyyxxxx - auto* u32s = ctx.dst->vec2(load_u32()); + auto* u32s = b.vec2(load_u32()); // xxxx0000, yyyyxxxx - auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2(16_u, 0_u)); + auto* shl = b.Shl(u32s, b.vec2(16_u, 0_u)); // 0000xxxx, 0000yyyy - return ctx.dst->Shr(shl, ctx.dst->vec2(16_u)); + return b.Shr(shl, b.vec2(16_u)); } case VertexFormat::kUint16x4: { // yyyyxxxx, wwwwzzzz - auto* u32s = ctx.dst->vec2(load_u32(), load_next_u32()); + auto* u32s = b.vec2(load_u32(), load_next_u32()); // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz - auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy"); + auto* xxyy = b.MemberAccessor(u32s, "xxyy"); // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz - auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4(16_u, 0_u, 16_u, 0_u)); + auto* shl = b.Shl(xxyy, b.vec4(16_u, 0_u, 16_u, 0_u)); // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww - return ctx.dst->Shr(shl, ctx.dst->vec4(16_u)); + return b.Shr(shl, b.vec4(16_u)); } case VertexFormat::kSint8x2: { // yyxx0000, yyxx0000 - auto* i16s = ctx.dst->vec2(load_i16_h()); + auto* i16s = b.vec2(load_i16_h()); // xx000000, yyxx0000 - auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2(8_u, 0_u)); + auto* shl = b.Shl(i16s, b.vec2(8_u, 0_u)); // ssssssxx, ssssssyy - return ctx.dst->Shr(shl, ctx.dst->vec2(24_u)); + return b.Shr(shl, b.vec2(24_u)); } case VertexFormat::kSint8x4: { // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx - auto* i32s = ctx.dst->vec4(load_i32()); + auto* i32s = b.vec4(load_i32()); // xx000000, yyxx0000, zzyyxx00, wwzzyyxx - auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4(24_u, 16_u, 8_u, 0_u)); + auto* shl = b.Shl(i32s, b.vec4(24_u, 16_u, 8_u, 0_u)); // ssssssxx, ssssssyy, sssssszz, ssssssww - return ctx.dst->Shr(shl, ctx.dst->vec4(24_u)); + return b.Shr(shl, b.vec4(24_u)); } case VertexFormat::kSint16x2: { // yyyyxxxx, yyyyxxxx - auto* i32s = ctx.dst->vec2(load_i32()); + auto* i32s = b.vec2(load_i32()); // xxxx0000, yyyyxxxx - auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2(16_u, 0_u)); + auto* shl = b.Shl(i32s, b.vec2(16_u, 0_u)); // ssssxxxx, ssssyyyy - return ctx.dst->Shr(shl, ctx.dst->vec2(16_u)); + return b.Shr(shl, b.vec2(16_u)); } case VertexFormat::kSint16x4: { // yyyyxxxx, wwwwzzzz - auto* i32s = ctx.dst->vec2(load_i32(), load_next_i32()); + auto* i32s = b.vec2(load_i32(), load_next_i32()); // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz - auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy"); + auto* xxyy = b.MemberAccessor(i32s, "xxyy"); // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz - auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4(16_u, 0_u, 16_u, 0_u)); + auto* shl = b.Shl(xxyy, b.vec4(16_u, 0_u, 16_u, 0_u)); // ssssxxxx, ssssyyyy, sssszzzz, sssswwww - return ctx.dst->Shr(shl, ctx.dst->vec4(16_u)); + return b.Shr(shl, b.vec4(16_u)); } case VertexFormat::kUnorm8x2: - return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy"); + return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy"); case VertexFormat::kSnorm8x2: - return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy"); + return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy"); case VertexFormat::kUnorm8x4: - return ctx.dst->Call("unpack4x8unorm", load_u32()); + return b.Call("unpack4x8unorm", load_u32()); case VertexFormat::kSnorm8x4: - return ctx.dst->Call("unpack4x8snorm", load_u32()); + return b.Call("unpack4x8snorm", load_u32()); case VertexFormat::kUnorm16x2: - return ctx.dst->Call("unpack2x16unorm", load_u32()); + return b.Call("unpack2x16unorm", load_u32()); case VertexFormat::kSnorm16x2: - return ctx.dst->Call("unpack2x16snorm", load_u32()); + return b.Call("unpack2x16snorm", load_u32()); case VertexFormat::kFloat16x2: - return ctx.dst->Call("unpack2x16float", load_u32()); + return b.Call("unpack2x16float", load_u32()); case VertexFormat::kUnorm16x4: - return ctx.dst->vec4(ctx.dst->Call("unpack2x16unorm", load_u32()), - ctx.dst->Call("unpack2x16unorm", load_next_u32())); + return b.vec4(b.Call("unpack2x16unorm", load_u32()), + b.Call("unpack2x16unorm", load_next_u32())); case VertexFormat::kSnorm16x4: - return ctx.dst->vec4(ctx.dst->Call("unpack2x16snorm", load_u32()), - ctx.dst->Call("unpack2x16snorm", load_next_u32())); + return b.vec4(b.Call("unpack2x16snorm", load_u32()), + b.Call("unpack2x16snorm", load_next_u32())); case VertexFormat::kFloat16x4: - return ctx.dst->vec4(ctx.dst->Call("unpack2x16float", load_u32()), - ctx.dst->Call("unpack2x16float", load_next_u32())); + return b.vec4(b.Call("unpack2x16float", load_u32()), + b.Call("unpack2x16float", load_next_u32())); } - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) - << "format " << static_cast(format); + TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast(format); return nullptr; } @@ -623,12 +656,12 @@ struct State { const ast ::Expression* index = nullptr; if (offset > 0) { - index = ctx.dst->Add(array_base, u32(offset / 4)); + index = b.Add(array_base, u32(offset / 4)); } else { - index = ctx.dst->Expr(array_base); + index = b.Expr(array_base); } - u = ctx.dst->IndexAccessor( - ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); + u = b.IndexAccessor( + b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index); } else { // Unaligned load @@ -639,22 +672,22 @@ struct State { uint32_t shift = 8u * (offset & 3u); - auto* low_shr = ctx.dst->Shr(low, u32(shift)); - auto* high_shl = ctx.dst->Shl(high, u32(32u - shift)); - u = ctx.dst->Or(low_shr, high_shl); + auto* low_shr = b.Shr(low, u32(shift)); + auto* high_shl = b.Shl(high, u32(32u - shift)); + u = b.Or(low_shr, high_shl); } switch (format) { case VertexFormat::kUint32: return u; case VertexFormat::kSint32: - return ctx.dst->Bitcast(ctx.dst->ty.i32(), u); + return b.Bitcast(b.ty.i32(), u); case VertexFormat::kFloat32: - return ctx.dst->Bitcast(ctx.dst->ty.f32(), u); + return b.Bitcast(b.ty.f32(), u); default: break; } - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) + TINT_UNREACHABLE(Transform, b.Diagnostics()) << "invalid format for LoadPrimitive" << static_cast(format); return nullptr; } @@ -682,8 +715,7 @@ struct State { expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format)); } - return ctx.dst->Construct(ctx.dst->create(base_type, count), - std::move(expr_list)); + return b.Construct(b.create(base_type, count), std::move(expr_list)); } /// Process a non-struct entry point parameter. @@ -696,34 +728,30 @@ struct State { // Create a function-scope variable to replace the parameter. auto func_var_sym = ctx.Clone(param->symbol); auto* func_var_type = ctx.Clone(param->type); - auto* func_var = ctx.dst->Var(func_var_sym, func_var_type); - ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); + auto* func_var = b.Var(func_var_sym, func_var_type); + ctx.InsertFront(func->body->statements, b.Decl(func_var)); // Capture mapping from location to the new variable. LocationInfo info; - info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); }; + info.expr = [this, func_var]() { return b.Expr(func_var); }; - auto* sem = ctx.src->Sem().Get(param); + auto* sem = src->Sem().Get(param); info.type = sem->Type(); if (!sem->Location().has_value()) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value"; + TINT_ICE(Transform, b.Diagnostics()) << "Location missing value"; return; } location_info[sem->Location().value()] = info; } else if (auto* builtin = ast::GetAttribute(param->attributes)) { // Check for existing vertex_index and instance_index builtins. if (builtin->builtin == ast::BuiltinValue::kVertexIndex) { - vertex_index_expr = [this, param]() { - return ctx.dst->Expr(ctx.Clone(param->symbol)); - }; + vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; } else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) { - instance_index_expr = [this, param]() { - return ctx.dst->Expr(ctx.Clone(param->symbol)); - }; + instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); }; } new_function_parameters.Push(ctx.Clone(param)); } else { - TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; + TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; } } @@ -746,7 +774,7 @@ struct State { for (auto* member : struct_ty->members) { auto member_sym = ctx.Clone(member->symbol); std::function member_expr = [this, param_sym, member_sym]() { - return ctx.dst->MemberAccessor(param_sym, member_sym); + return b.MemberAccessor(param_sym, member_sym); }; if (ast::HasAttribute(member->attributes)) { @@ -754,7 +782,7 @@ struct State { LocationInfo info; info.expr = member_expr; - auto* sem = ctx.src->Sem().Get(member); + auto* sem = src->Sem().Get(member); info.type = sem->Type(); TINT_ASSERT(Transform, sem->Location().has_value()); @@ -770,7 +798,7 @@ struct State { } members_to_clone.Push(member); } else { - TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter"; + TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; } } @@ -781,8 +809,8 @@ struct State { } // Create a function-scope variable to replace the parameter. - auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type)); - ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var)); + auto* func_var = b.Var(param_sym, ctx.Clone(param->type)); + ctx.InsertFront(func->body->statements, b.Decl(func_var)); if (!members_to_clone.IsEmpty()) { // Create a new struct without the location attributes. @@ -791,20 +819,20 @@ struct State { auto member_sym = ctx.Clone(member->symbol); auto* member_type = ctx.Clone(member->type); auto member_attrs = ctx.Clone(member->attributes); - new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs))); + new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs))); } - auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members); + auto* new_struct = b.Structure(b.Sym(), new_members); // Create a new function parameter with this struct. - auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct)); + auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct)); new_function_parameters.Push(new_param); // Copy values from the new parameter to the function-scope variable. for (auto* member : members_to_clone) { auto member_name = ctx.Clone(member->symbol); ctx.InsertFront(func->body->statements, - ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name), - ctx.dst->MemberAccessor(new_param, member_name))); + b.Assign(b.MemberAccessor(func_var, member_name), + b.MemberAccessor(new_param, member_name))); } } } @@ -818,7 +846,7 @@ struct State { // Process entry point parameters. for (auto* param : func->params) { - auto* sem = ctx.src->Sem().Get(param); + auto* sem = src->Sem().Get(param); if (auto* str = sem->Type()->As()) { ProcessStructParameter(func, param, str->Declaration()); } else { @@ -830,11 +858,11 @@ struct State { if (!vertex_index_expr) { for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { if (layout.step_mode == VertexStepMode::kVertex) { - auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); - new_function_parameters.Push(ctx.dst->Param( - name, ctx.dst->ty.u32(), - utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)})); - vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; + auto name = b.Symbols().New("tint_pulling_vertex_index"); + new_function_parameters.Push( + b.Param(name, b.ty.u32(), + utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)})); + vertex_index_expr = [this, name]() { return b.Expr(name); }; break; } } @@ -842,11 +870,11 @@ struct State { if (!instance_index_expr) { for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { if (layout.step_mode == VertexStepMode::kInstance) { - auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); - new_function_parameters.Push(ctx.dst->Param( - name, ctx.dst->ty.u32(), - utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)})); - instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; + auto name = b.Symbols().New("tint_pulling_instance_index"); + new_function_parameters.Push( + b.Param(name, b.ty.u32(), + utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)})); + instance_index_expr = [this, name]() { return b.Expr(name); }; break; } } @@ -864,53 +892,24 @@ struct State { auto attrs = ctx.Clone(func->attributes); auto ret_attrs = ctx.Clone(func->return_type_attributes); auto* new_func = - ctx.dst->create(func->source, func_sym, new_function_parameters, - ret_type, body, std::move(attrs), std::move(ret_attrs)); + b.create(func->source, func_sym, new_function_parameters, ret_type, body, + std::move(attrs), std::move(ret_attrs)); ctx.Replace(func, new_func); } }; -} // namespace - VertexPulling::VertexPulling() = default; VertexPulling::~VertexPulling() = default; -void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { +Transform::ApplyResult VertexPulling::Apply(const Program* src, + const DataMap& inputs, + DataMap&) const { auto cfg = cfg_; if (auto* cfg_data = inputs.Get()) { cfg = *cfg_data; } - // Find entry point - const ast::Function* func = nullptr; - for (auto* fn : ctx.src->AST().Functions()) { - if (fn->PipelineStage() == ast::PipelineStage::kVertex) { - if (func != nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "VertexPulling found more than one vertex entry point"); - return; - } - func = fn; - } - } - if (func == nullptr) { - ctx.dst->Diagnostics().add_error(diag::System::Transform, - "Vertex stage entry point not found"); - return; - } - - // TODO(idanr): Need to check shader locations in descriptor cover all - // attributes - - // TODO(idanr): Make sure we covered all error cases, to guarantee the - // following stages will pass - - State state{ctx, cfg}; - state.AddVertexStorageBuffers(); - state.Process(func); - - ctx.Clone(); + return State{src, cfg}.Run(); } VertexPulling::Config::Config() = default; diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h index 6dd35bc85a..c0f88a596e 100644 --- a/src/tint/transform/vertex_pulling.h +++ b/src/tint/transform/vertex_pulling.h @@ -171,16 +171,14 @@ class VertexPulling final : public Castable { /// Destructor ~VertexPulling() override; - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; private: + struct State; + Config cfg_; }; diff --git a/src/tint/transform/while_to_loop.cc b/src/tint/transform/while_to_loop.cc index 45944e67ae..d359d2e796 100644 --- a/src/tint/transform/while_to_loop.cc +++ b/src/tint/transform/while_to_loop.cc @@ -14,18 +14,17 @@ #include "src/tint/transform/while_to_loop.h" +#include + #include "src/tint/ast/break_statement.h" #include "src/tint/program_builder.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop); namespace tint::transform { +namespace { -WhileToLoop::WhileToLoop() = default; - -WhileToLoop::~WhileToLoop() = default; - -bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { +bool ShouldRun(const Program* program) { for (auto* node : program->ASTNodes().Objects()) { if (node->Is()) { return true; @@ -34,20 +33,32 @@ bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const { return false; } -void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { +} // namespace + +WhileToLoop::WhileToLoop() = default; + +WhileToLoop::~WhileToLoop() = default; + +Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; + } + + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* { utils::Vector stmts; auto* cond = w->condition; // !condition - auto* not_cond = - ctx.dst->create(ast::UnaryOp::kNot, ctx.Clone(cond)); + auto* not_cond = b.Not(ctx.Clone(cond)); // { break; } - auto* break_body = ctx.dst->Block(ctx.dst->create()); + auto* break_body = b.Block(b.Break()); // if (!condition) { break; } - stmts.Push(ctx.dst->If(not_cond, break_body)); + stmts.Push(b.If(not_cond, break_body)); for (auto* stmt : w->body->statements) { stmts.Push(ctx.Clone(stmt)); @@ -55,13 +66,14 @@ void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { const ast::BlockStatement* continuing = nullptr; - auto* body = ctx.dst->Block(stmts); - auto* loop = ctx.dst->create(body, continuing); + auto* body = b.Block(stmts); + auto* loop = b.Loop(body, continuing); return loop; }); ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/while_to_loop.h b/src/tint/transform/while_to_loop.h index 4915d681e6..187799a845 100644 --- a/src/tint/transform/while_to_loop.h +++ b/src/tint/transform/while_to_loop.h @@ -29,19 +29,10 @@ class WhileToLoop final : public Castable { /// Destructor ~WhileToLoop() override; - /// @param program the program to inspect - /// @param data optional extra transform-specific input data - /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program, const DataMap& data = {}) const override; - - protected: - /// Runs the transform using the CloneContext built for transforming a - /// program. Run() is responsible for calling Clone() on the CloneContext. - /// @param ctx the CloneContext primed with the input program and - /// ProgramBuilder - /// @param inputs optional extra transform-specific input data - /// @param outputs optional extra transform-specific output data - void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override; + /// @copydoc Transform::Apply + ApplyResult Apply(const Program* program, + const DataMap& inputs, + DataMap& outputs) const override; }; } // namespace tint::transform diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc index ea654361c0..ed3584e968 100644 --- a/src/tint/transform/zero_init_workgroup_memory.cc +++ b/src/tint/transform/zero_init_workgroup_memory.cc @@ -31,10 +31,24 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory); namespace tint::transform { +namespace { + +bool ShouldRun(const Program* program) { + for (auto* global : program->AST().GlobalVariables()) { + if (auto* var = global->As()) { + if (var->declared_address_space == ast::AddressSpace::kWorkgroup) { + return true; + } + } + } + return false; +} + +} // namespace using StatementList = utils::Vector; -/// PIMPL state for the ZeroInitWorkgroupMemory transform +/// PIMPL state for the transform struct ZeroInitWorkgroupMemory::State { /// The clone context CloneContext& ctx; @@ -424,24 +438,24 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; -bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const { - for (auto* global : program->AST().GlobalVariables()) { - if (auto* var = global->As()) { - if (var->declared_address_space == ast::AddressSpace::kWorkgroup) { - return true; - } - } +Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src, + const DataMap&, + DataMap&) const { + if (!ShouldRun(src)) { + return SkipTransform; } - return false; -} -void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - for (auto* fn : ctx.src->AST().Functions()) { + ProgramBuilder b; + CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; + + for (auto* fn : src->AST().Functions()) { if (fn->PipelineStage() == ast::PipelineStage::kCompute) { State{ctx}.Run(fn); } } + ctx.Clone(); + return Program(std::move(b)); } } // namespace tint::transform diff --git a/src/tint/transform/zero_init_workgroup_memory.h b/src/tint/transform/zero_init_workgroup_memory.h index 07feaa895a..64f4da8f1c 100644 --- a/src/tint/transform/zero_init_workgroup_memory.h +++ b/src/tint/transform/zero_init_workgroup_memory.h @@ -30,19 +30,10 @@ class ZeroInitWorkgroupMemory final : public Castable plane0, Texture2D plane1, i return float4(color, 1.0f); } -int2 tint_clamp(int2 e, int2 low, int2 high) { - return min(max(e, low), high); -} - float3x4 tint_symbol_6(uint4 buffer[11], uint offset) { const uint scalar_offset = ((offset + 0u)) / 4; const uint scalar_offset_1 = ((offset + 16u)) / 4; diff --git a/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl index 6a6f107f03..a80ef5b69d 100644 --- a/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl +++ b/test/tint/bug/tint/1739.wgsl.expected.fxc.hlsl @@ -1,3 +1,7 @@ +int2 tint_clamp(int2 e, int2 low, int2 high) { + return min(max(e, low), high); +} + struct GammaTransferParams { float G; float A; @@ -46,10 +50,6 @@ float4 textureLoadExternal(Texture2D plane0, Texture2D plane1, i return float4(color, 1.0f); } -int2 tint_clamp(int2 e, int2 low, int2 high) { - return min(max(e, low), high); -} - float3x4 tint_symbol_6(uint4 buffer[11], uint offset) { const uint scalar_offset = ((offset + 0u)) / 4; const uint scalar_offset_1 = ((offset + 16u)) / 4; diff --git a/test/tint/bug/tint/1739.wgsl.expected.msl b/test/tint/bug/tint/1739.wgsl.expected.msl index e1802c56dd..ed7a995d9c 100644 --- a/test/tint/bug/tint/1739.wgsl.expected.msl +++ b/test/tint/bug/tint/1739.wgsl.expected.msl @@ -14,6 +14,10 @@ struct tint_array { T elements[N]; }; +int2 tint_clamp(int2 e, int2 low, int2 high) { + return min(max(e, low), high); +} + struct GammaTransferParams { /* 0x0000 */ float G; /* 0x0004 */ float A; @@ -57,10 +61,6 @@ float4 textureLoadExternal(texture2d plane0, texture2d tint_symbol_5 [[texture(0)]], texture2d tint_symbol_6 [[texture(1)]], const constant ExternalTextureParams* tint_symbol_7 [[buffer(0)]], texture2d tint_symbol_8 [[texture(2)]]) { int2 const tint_symbol_1 = tint_clamp(int2(10), int2(0), int2((uint2(uint2(tint_symbol_5.get_width(), tint_symbol_5.get_height())) - uint2(1u)))); float4 red = textureLoadExternal(tint_symbol_5, tint_symbol_6, tint_symbol_1, *(tint_symbol_7)); diff --git a/test/tint/bug/tint/1739.wgsl.expected.spvasm b/test/tint/bug/tint/1739.wgsl.expected.spvasm index 2cdbccb29f..2f01aa56df 100644 --- a/test/tint/bug/tint/1739.wgsl.expected.spvasm +++ b/test/tint/bug/tint/1739.wgsl.expected.spvasm @@ -5,7 +5,7 @@ ; Schema: 0 OpCapability Shader OpCapability ImageQuery - %25 = OpExtInstImport "GLSL.std.450" + %28 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" OpExecutionMode %main LocalSize 1 1 1 @@ -31,6 +31,10 @@ OpName %ext_tex_params "ext_tex_params" OpName %t "t" OpName %outImage "outImage" + OpName %tint_clamp "tint_clamp" + OpName %e "e" + OpName %low "low" + OpName %high "high" OpName %gammaCorrection "gammaCorrection" OpName %v "v" OpName %params "params" @@ -40,10 +44,6 @@ OpName %coord "coord" OpName %params_0 "params" OpName %color "color" - OpName %tint_clamp "tint_clamp" - OpName %e "e" - OpName %low "low" - OpName %high "high" OpName %main "main" OpName %red "red" OpName %green "green" @@ -95,20 +95,20 @@ %18 = OpTypeImage %float 2D 0 0 0 2 Rgba8 %_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 %outImage = OpVariable %_ptr_UniformConstant_18 UniformConstant - %19 = OpTypeFunction %v3float %v3float %GammaTransferParams + %int = OpTypeInt 32 1 + %v2int = OpTypeVector %int 2 + %19 = OpTypeFunction %v2int %v2int %v2int %v2int + %30 = OpTypeFunction %v3float %v3float %GammaTransferParams %bool = OpTypeBool %v3bool = OpTypeVector %bool 3 %_ptr_Function_v3float = OpTypePointer Function %v3float - %39 = OpConstantNull %v3float - %int = OpTypeInt 32 1 - %v2int = OpTypeVector %int 2 - %59 = OpTypeFunction %v4float %3 %3 %v2int %ExternalTextureParams + %49 = OpConstantNull %v3float + %69 = OpTypeFunction %v4float %3 %3 %v2int %ExternalTextureParams %uint_1 = OpConstant %uint 1 - %76 = OpConstantNull %int + %84 = OpConstantNull %int %v2float = OpTypeVector %float 2 %float_1 = OpConstant %float 1 - %90 = OpConstantNull %uint - %108 = OpTypeFunction %v2int %v2int %v2int %v2int + %98 = OpConstantNull %uint %void = OpTypeVoid %116 = OpTypeFunction %void %int_10 = OpConstant %int 10 @@ -125,106 +125,106 @@ %int_118 = OpConstant %int 118 %154 = OpConstantComposite %v2int %int_70 %int_118 %int_1 = OpConstant %int 1 - %168 = OpConstantComposite %v2int %int_1 %76 -%gammaCorrection = OpFunction %v3float None %19 + %168 = OpConstantComposite %v2int %int_1 %84 + %tint_clamp = OpFunction %v2int None %19 + %e = OpFunctionParameter %v2int + %low = OpFunctionParameter %v2int + %high = OpFunctionParameter %v2int + %26 = OpLabel + %29 = OpExtInst %v2int %28 SMax %e %low + %27 = OpExtInst %v2int %28 SMin %29 %high + OpReturnValue %27 + OpFunctionEnd +%gammaCorrection = OpFunction %v3float None %30 %v = OpFunctionParameter %v3float %params = OpFunctionParameter %GammaTransferParams - %23 = OpLabel - %37 = OpVariable %_ptr_Function_v3float Function %39 - %49 = OpVariable %_ptr_Function_v3float Function %39 - %55 = OpVariable %_ptr_Function_v3float Function %39 - %24 = OpExtInst %v3float %25 FAbs %v - %26 = OpCompositeExtract %float %params 4 - %27 = OpCompositeConstruct %v3float %26 %26 %26 - %28 = OpFOrdLessThan %v3bool %24 %27 - %31 = OpExtInst %v3float %25 FSign %v - %32 = OpCompositeExtract %float %params 3 - %33 = OpExtInst %v3float %25 FAbs %v - %34 = OpVectorTimesScalar %v3float %33 %32 - %35 = OpCompositeExtract %float %params 6 - %40 = OpCompositeConstruct %v3float %35 %35 %35 - %36 = OpFAdd %v3float %34 %40 - %41 = OpFMul %v3float %31 %36 - %42 = OpExtInst %v3float %25 FSign %v - %44 = OpCompositeExtract %float %params 1 - %45 = OpExtInst %v3float %25 FAbs %v - %46 = OpVectorTimesScalar %v3float %45 %44 - %47 = OpCompositeExtract %float %params 2 - %50 = OpCompositeConstruct %v3float %47 %47 %47 - %48 = OpFAdd %v3float %46 %50 - %51 = OpCompositeExtract %float %params 0 - %52 = OpCompositeConstruct %v3float %51 %51 %51 - %43 = OpExtInst %v3float %25 Pow %48 %52 - %53 = OpCompositeExtract %float %params 5 - %56 = OpCompositeConstruct %v3float %53 %53 %53 - %54 = OpFAdd %v3float %43 %56 - %57 = OpFMul %v3float %42 %54 - %58 = OpSelect %v3float %28 %41 %57 - OpReturnValue %58 + %34 = OpLabel + %47 = OpVariable %_ptr_Function_v3float Function %49 + %59 = OpVariable %_ptr_Function_v3float Function %49 + %65 = OpVariable %_ptr_Function_v3float Function %49 + %35 = OpExtInst %v3float %28 FAbs %v + %36 = OpCompositeExtract %float %params 4 + %37 = OpCompositeConstruct %v3float %36 %36 %36 + %38 = OpFOrdLessThan %v3bool %35 %37 + %41 = OpExtInst %v3float %28 FSign %v + %42 = OpCompositeExtract %float %params 3 + %43 = OpExtInst %v3float %28 FAbs %v + %44 = OpVectorTimesScalar %v3float %43 %42 + %45 = OpCompositeExtract %float %params 6 + %50 = OpCompositeConstruct %v3float %45 %45 %45 + %46 = OpFAdd %v3float %44 %50 + %51 = OpFMul %v3float %41 %46 + %52 = OpExtInst %v3float %28 FSign %v + %54 = OpCompositeExtract %float %params 1 + %55 = OpExtInst %v3float %28 FAbs %v + %56 = OpVectorTimesScalar %v3float %55 %54 + %57 = OpCompositeExtract %float %params 2 + %60 = OpCompositeConstruct %v3float %57 %57 %57 + %58 = OpFAdd %v3float %56 %60 + %61 = OpCompositeExtract %float %params 0 + %62 = OpCompositeConstruct %v3float %61 %61 %61 + %53 = OpExtInst %v3float %28 Pow %58 %62 + %63 = OpCompositeExtract %float %params 5 + %66 = OpCompositeConstruct %v3float %63 %63 %63 + %64 = OpFAdd %v3float %53 %66 + %67 = OpFMul %v3float %52 %64 + %68 = OpSelect %v3float %38 %51 %67 + OpReturnValue %68 OpFunctionEnd -%textureLoadExternal = OpFunction %v4float None %59 +%textureLoadExternal = OpFunction %v4float None %69 %plane0 = OpFunctionParameter %3 %plane1 = OpFunctionParameter %3 %coord = OpFunctionParameter %v2int %params_0 = OpFunctionParameter %ExternalTextureParams - %67 = OpLabel - %color = OpVariable %_ptr_Function_v3float Function %39 - %69 = OpCompositeExtract %uint %params_0 0 - %71 = OpIEqual %bool %69 %uint_1 - OpSelectionMerge %72 None - OpBranchConditional %71 %73 %74 - %73 = OpLabel - %75 = OpImageFetch %v4float %plane0 %coord Lod %76 - %77 = OpVectorShuffle %v3float %75 %75 0 1 2 - OpStore %color %77 - OpBranch %72 - %74 = OpLabel - %78 = OpImageFetch %v4float %plane0 %coord Lod %76 - %79 = OpCompositeExtract %float %78 0 - %80 = OpImageFetch %v4float %plane1 %coord Lod %76 - %82 = OpVectorShuffle %v2float %80 %80 0 1 - %83 = OpCompositeExtract %float %82 0 - %84 = OpCompositeExtract %float %82 1 - %86 = OpCompositeConstruct %v4float %79 %83 %84 %float_1 - %87 = OpCompositeExtract %mat3v4float %params_0 2 - %88 = OpVectorTimesMatrix %v3float %86 %87 - OpStore %color %88 - OpBranch %72 - %72 = OpLabel - %89 = OpCompositeExtract %uint %params_0 1 - %91 = OpIEqual %bool %89 %90 - OpSelectionMerge %92 None - OpBranchConditional %91 %93 %92 - %93 = OpLabel - %95 = OpLoad %v3float %color - %96 = OpCompositeExtract %GammaTransferParams %params_0 3 - %94 = OpFunctionCall %v3float %gammaCorrection %95 %96 - OpStore %color %94 - %97 = OpCompositeExtract %mat3v3float %params_0 5 - %98 = OpLoad %v3float %color - %99 = OpMatrixTimesVector %v3float %97 %98 - OpStore %color %99 - %101 = OpLoad %v3float %color - %102 = OpCompositeExtract %GammaTransferParams %params_0 4 - %100 = OpFunctionCall %v3float %gammaCorrection %101 %102 - OpStore %color %100 - OpBranch %92 - %92 = OpLabel + %75 = OpLabel + %color = OpVariable %_ptr_Function_v3float Function %49 + %77 = OpCompositeExtract %uint %params_0 0 + %79 = OpIEqual %bool %77 %uint_1 + OpSelectionMerge %80 None + OpBranchConditional %79 %81 %82 + %81 = OpLabel + %83 = OpImageFetch %v4float %plane0 %coord Lod %84 + %85 = OpVectorShuffle %v3float %83 %83 0 1 2 + OpStore %color %85 + OpBranch %80 + %82 = OpLabel + %86 = OpImageFetch %v4float %plane0 %coord Lod %84 + %87 = OpCompositeExtract %float %86 0 + %88 = OpImageFetch %v4float %plane1 %coord Lod %84 + %90 = OpVectorShuffle %v2float %88 %88 0 1 + %91 = OpCompositeExtract %float %90 0 + %92 = OpCompositeExtract %float %90 1 + %94 = OpCompositeConstruct %v4float %87 %91 %92 %float_1 + %95 = OpCompositeExtract %mat3v4float %params_0 2 + %96 = OpVectorTimesMatrix %v3float %94 %95 + OpStore %color %96 + OpBranch %80 + %80 = OpLabel + %97 = OpCompositeExtract %uint %params_0 1 + %99 = OpIEqual %bool %97 %98 + OpSelectionMerge %100 None + OpBranchConditional %99 %101 %100 + %101 = OpLabel %103 = OpLoad %v3float %color - %104 = OpCompositeExtract %float %103 0 - %105 = OpCompositeExtract %float %103 1 - %106 = OpCompositeExtract %float %103 2 - %107 = OpCompositeConstruct %v4float %104 %105 %106 %float_1 - OpReturnValue %107 - OpFunctionEnd - %tint_clamp = OpFunction %v2int None %108 - %e = OpFunctionParameter %v2int - %low = OpFunctionParameter %v2int - %high = OpFunctionParameter %v2int - %113 = OpLabel - %115 = OpExtInst %v2int %25 SMax %e %low - %114 = OpExtInst %v2int %25 SMin %115 %high - OpReturnValue %114 + %104 = OpCompositeExtract %GammaTransferParams %params_0 3 + %102 = OpFunctionCall %v3float %gammaCorrection %103 %104 + OpStore %color %102 + %105 = OpCompositeExtract %mat3v3float %params_0 5 + %106 = OpLoad %v3float %color + %107 = OpMatrixTimesVector %v3float %105 %106 + OpStore %color %107 + %109 = OpLoad %v3float %color + %110 = OpCompositeExtract %GammaTransferParams %params_0 4 + %108 = OpFunctionCall %v3float %gammaCorrection %109 %110 + OpStore %color %108 + OpBranch %100 + %100 = OpLabel + %111 = OpLoad %v3float %color + %112 = OpCompositeExtract %float %111 0 + %113 = OpCompositeExtract %float %111 1 + %114 = OpCompositeExtract %float %111 2 + %115 = OpCompositeConstruct %v4float %112 %113 %114 %float_1 + OpReturnValue %115 OpFunctionEnd %main = OpFunction %void None %116 %119 = OpLabel diff --git a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl index c1df2b14ec..27991572a6 100644 --- a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl +++ b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl @@ -1,6 +1,3 @@ -Texture2D arg_0 : register(t0, space1); -SamplerState arg_1 : register(s1, space1); - float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, float2 coord) { int3 tint_tmp; t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z); @@ -10,6 +7,9 @@ float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, fl return t.SampleLevel(s, clamped, 0.0f); } +Texture2D arg_0 : register(t0, space1); +SamplerState arg_1 : register(s1, space1); + void textureSampleBaseClampToEdge_9ca02c() { float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, (0.0f).xx); } diff --git a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl index c1df2b14ec..27991572a6 100644 --- a/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl +++ b/test/tint/builtins/gen/literal/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl @@ -1,6 +1,3 @@ -Texture2D arg_0 : register(t0, space1); -SamplerState arg_1 : register(s1, space1); - float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, float2 coord) { int3 tint_tmp; t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z); @@ -10,6 +7,9 @@ float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, fl return t.SampleLevel(s, clamped, 0.0f); } +Texture2D arg_0 : register(t0, space1); +SamplerState arg_1 : register(s1, space1); + void textureSampleBaseClampToEdge_9ca02c() { float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, (0.0f).xx); } diff --git a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl index 870d5d4c46..1e3e8bdc72 100644 --- a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl +++ b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.dxc.hlsl @@ -1,6 +1,3 @@ -Texture2D arg_0 : register(t0, space1); -SamplerState arg_1 : register(s1, space1); - float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, float2 coord) { int3 tint_tmp; t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z); @@ -10,6 +7,9 @@ float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, fl return t.SampleLevel(s, clamped, 0.0f); } +Texture2D arg_0 : register(t0, space1); +SamplerState arg_1 : register(s1, space1); + void textureSampleBaseClampToEdge_9ca02c() { float2 arg_2 = (0.0f).xx; float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, arg_2); diff --git a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl index 870d5d4c46..1e3e8bdc72 100644 --- a/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl +++ b/test/tint/builtins/gen/var/textureSampleBaseClampToEdge/9ca02c.wgsl.expected.fxc.hlsl @@ -1,6 +1,3 @@ -Texture2D arg_0 : register(t0, space1); -SamplerState arg_1 : register(s1, space1); - float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, float2 coord) { int3 tint_tmp; t.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z); @@ -10,6 +7,9 @@ float4 tint_textureSampleBaseClampToEdge(Texture2D t, SamplerState s, fl return t.SampleLevel(s, clamped, 0.0f); } +Texture2D arg_0 : register(t0, space1); +SamplerState arg_1 : register(s1, space1); + void textureSampleBaseClampToEdge_9ca02c() { float2 arg_2 = (0.0f).xx; float4 res = tint_textureSampleBaseClampToEdge(arg_0, arg_1, arg_2);