transform: Make methods const, add ShouldRun()

Transforms are supposed to be immutable, operating on the DataMaps provided for input and output, so make the methods const.

Add a ShouldRun() method which the Manager can use to skip over transforms that do not need to be run.

Change-Id: I320ac964577e94ac988748d8aca85bd43ee8d3b5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77120
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-01-25 20:53:25 +00:00 committed by Tint LUCI CQ
parent c8c0e2ea38
commit 12d54d746e
73 changed files with 238 additions and 101 deletions

View File

@ -52,7 +52,7 @@ Program Parse(const std::vector<uint32_t>& input) {
// If the generated program contains matrices with a custom MatrixStride // If the generated program contains matrices with a custom MatrixStride
// attribute then we need to decompose these into an array of vectors // attribute then we need to decompose these into an array of vectors
if (transform::DecomposeStridedMatrix::ShouldRun(&program)) { if (transform::DecomposeStridedMatrix().ShouldRun(&program)) {
transform::Manager manager; transform::Manager manager;
manager.Add<transform::Unshadow>(); manager.Add<transform::Unshadow>();
manager.Add<transform::SimplifyPointers>(); manager.Add<transform::SimplifyPointers>();

View File

@ -27,7 +27,9 @@ AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) { void AddEmptyEntryPoint::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
for (auto* func : ctx.src->AST().Functions()) { for (auto* func : ctx.src->AST().Functions()) {
if (func->IsEntryPoint()) { if (func->IsEntryPoint()) {
ctx.Clone(); ctx.Clone();

View File

@ -35,7 +35,9 @@ class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -33,7 +33,9 @@ AddSpirvBlockDecoration::AddSpirvBlockDecoration() = default;
AddSpirvBlockDecoration::~AddSpirvBlockDecoration() = default; AddSpirvBlockDecoration::~AddSpirvBlockDecoration() = default;
void AddSpirvBlockDecoration::Run(CloneContext& ctx, const DataMap&, DataMap&) { void AddSpirvBlockDecoration::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
// Collect the set of structs that are nested in other types. // Collect the set of structs that are nested in other types.

View File

@ -65,7 +65,9 @@ class AddSpirvBlockDecoration
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -95,7 +95,7 @@ static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
void ArrayLengthFromUniform::Run(CloneContext& ctx, void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) { DataMap& outputs) const {
if (!Requires<SimplifyPointers>(ctx)) { if (!Requires<SimplifyPointers>(ctx)) {
return; return;
} }

View File

@ -103,7 +103,9 @@ class ArrayLengthFromUniform
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -42,7 +42,9 @@ BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default; BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default;
void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { void BindingRemapper::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* remappings = inputs.Get<Remappings>(); auto* remappings = inputs.Get<Remappings>();
if (!remappings) { if (!remappings) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -75,7 +75,9 @@ class BindingRemapper : public Castable<BindingRemapper, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -71,7 +71,9 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default;
void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) { void CalculateArrayLength::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
if (!Requires<SimplifyPointers>(ctx)) { if (!Requires<SimplifyPointers>(ctx)) {
return; return;

View File

@ -63,7 +63,9 @@ class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -550,7 +550,7 @@ struct CanonicalizeEntryPointIO::State {
void CanonicalizeEntryPointIO::Run(CloneContext& ctx, void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) { DataMap&) const {
if (!Requires<Unshadow>(ctx)) { if (!Requires<Unshadow>(ctx)) {
return; return;
} }

View File

@ -131,7 +131,9 @@ class CanonicalizeEntryPointIO
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State; struct State;
}; };

View File

@ -304,7 +304,9 @@ CombineSamplers::CombineSamplers() = default;
CombineSamplers::~CombineSamplers() = default; CombineSamplers::~CombineSamplers() = default;
void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { void CombineSamplers::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* binding_info = inputs.Get<BindingInfo>(); auto* binding_info = inputs.Get<BindingInfo>();
if (!binding_info) { if (!binding_info) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -95,7 +95,9 @@ class CombineSamplers : public Castable<CombineSamplers, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -790,7 +790,9 @@ const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) { void DecomposeMemoryAccess::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
State state(ctx); State state(ctx);

View File

@ -112,7 +112,9 @@ class DecomposeMemoryAccess
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State; struct State;
}; };

View File

@ -107,7 +107,7 @@ DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program) { bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
bool should_run = false; bool should_run = false;
GatherCustomStrideMatrixMembers( GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) { program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
@ -117,7 +117,9 @@ bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
return should_run; return should_run;
} }
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) { void DecomposeStridedMatrix::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) { if (!Requires<SimplifyPointers>(ctx)) {
return; return;
} }

View File

@ -36,7 +36,7 @@ class DecomposeStridedMatrix
/// @param program the program to inspect /// @param program the program to inspect
/// @returns true if this transform should be run for the given program /// @returns true if this transform should be run for the given program
static bool ShouldRun(const Program* program); bool ShouldRun(const Program* program) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
@ -45,7 +45,9 @@ class DecomposeStridedMatrix
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -41,11 +41,45 @@ TEST_F(DecomposeStridedMatrixTest, Empty) {
} }
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) { TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
auto* src = R"()"; // struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)"; R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)";
auto got = Run<DecomposeStridedMatrix>(src); auto got = Run<DecomposeStridedMatrix>(Program(std::move(b)));
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }

View File

@ -28,7 +28,7 @@ ExternalTextureTransform::~ExternalTextureTransform() = default;
void ExternalTextureTransform::Run(CloneContext& ctx, void ExternalTextureTransform::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
// Within this transform, usages of texture_external are replaced with a // Within this transform, usages of texture_external are replaced with a

View File

@ -42,7 +42,9 @@ class ExternalTextureTransform
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -59,7 +59,7 @@ FirstIndexOffset::~FirstIndexOffset() = default;
void FirstIndexOffset::Run(CloneContext& ctx, void FirstIndexOffset::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) { DataMap& outputs) const {
// Get the uniform buffer binding point // Get the uniform buffer binding point
uint32_t ub_binding = binding_; uint32_t ub_binding = binding_;
uint32_t ub_group = group_; uint32_t ub_group = group_;

View File

@ -122,7 +122,9 @@ class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
uint32_t binding_ = 0; uint32_t binding_ = 0;

View File

@ -33,7 +33,7 @@ FoldConstants::FoldConstants() = default;
FoldConstants::~FoldConstants() = default; FoldConstants::~FoldConstants() = default;
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) { void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
auto* call = ctx.src->Sem().Get<sem::Call>(expr); auto* call = ctx.src->Sem().Get<sem::Call>(expr);
if (!call) { if (!call) {

View File

@ -36,7 +36,9 @@ class FoldConstants : public Castable<FoldConstants, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -51,7 +51,7 @@ FoldTrivialSingleUseLets::~FoldTrivialSingleUseLets() = default;
void FoldTrivialSingleUseLets::Run(CloneContext& ctx, void FoldTrivialSingleUseLets::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* block = node->As<ast::BlockStatement>()) { if (auto* block = node->As<ast::BlockStatement>()) {
auto& stmts = block->statements; auto& stmts = block->statements;

View File

@ -50,7 +50,9 @@ class FoldTrivialSingleUseLets
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -25,7 +25,7 @@ ForLoopToLoop::ForLoopToLoop() = default;
ForLoopToLoop::~ForLoopToLoop() = default; ForLoopToLoop::~ForLoopToLoop() = default;
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) { void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll( ctx.ReplaceAll(
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { [&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
ast::StatementList stmts; ast::StatementList stmts;

View File

@ -37,7 +37,9 @@ class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -45,7 +45,7 @@ namespace transform {
Glsl::Glsl() = default; Glsl::Glsl() = default;
Glsl::~Glsl() = default; Glsl::~Glsl() = default;
Output Glsl::Run(const Program* in, const DataMap& inputs) { Output Glsl::Run(const Program* in, const DataMap& inputs) const {
Manager manager; Manager manager;
DataMap data; DataMap data;

View File

@ -61,7 +61,7 @@ class Glsl : public Castable<Glsl, Transform> {
/// @param program the source program to transform /// @param program the source program to transform
/// @param data optional extra transform-specific data /// @param data optional extra transform-specific data
/// @returns the transformation result /// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -215,7 +215,7 @@ LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
void LocalizeStructArrayAssignment::Run(CloneContext& ctx, void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) { if (!Requires<SimplifyPointers>(ctx)) {
return; return;
} }

View File

@ -41,7 +41,9 @@ class LocalizeStructArrayAssignment
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
class State; class State;

View File

@ -56,7 +56,7 @@ LoopToForLoop::LoopToForLoop() = default;
LoopToForLoop::~LoopToForLoop() = default; LoopToForLoop::~LoopToForLoop() = default;
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) { void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* { ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
// For loop condition is taken from the first statement in the loop. // For loop condition is taken from the first statement in the loop.
// This requires an if-statement with either: // This requires an if-statement with either:

View File

@ -37,7 +37,9 @@ class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -32,7 +32,7 @@ namespace transform {
Manager::Manager() = default; Manager::Manager() = default;
Manager::~Manager() = default; Manager::~Manager() = default;
Output Manager::Run(const Program* program, const DataMap& data) { Output Manager::Run(const Program* program, const DataMap& data) const {
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM #if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) { auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(program); auto wgsl = Program::printer(program);
@ -49,16 +49,22 @@ Output Manager::Run(const Program* program, const DataMap& data) {
}; };
#endif #endif
const Program* in = program;
Output out; Output out;
if (!transforms_.empty()) {
for (const auto& transform : transforms_) { for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
<< transform->TypeInfo().name);
continue;
}
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
auto res = transform->Run(program, data); auto res = transform->Run(in, data);
out.program = std::move(res.program); out.program = std::move(res.program);
out.data.Add(std::move(res.data)); out.data.Add(std::move(res.data));
program = &out.program; in = &out.program;
if (!program->IsValid()) { if (!in->IsValid()) {
TINT_IF_PRINT_PROGRAM( TINT_IF_PRINT_PROGRAM(
print_program("Invalid output of", transform.get())); print_program("Invalid output of", transform.get()));
return out; return out;
@ -68,7 +74,8 @@ Output Manager::Run(const Program* program, const DataMap& data) {
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get())); TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
} }
} }
} else {
if (program == in) {
out.program = program->Clone(); out.program = program->Clone();
} }

View File

@ -52,7 +52,7 @@ class Manager : public Castable<Manager, Transform> {
/// @param program the source program to transform /// @param program the source program to transform
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformed program and diagnostics /// @returns the transformed program and diagnostics
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) const override;
private: private:
std::vector<std::unique_ptr<Transform>> transforms_; std::vector<std::unique_ptr<Transform>> transforms_;

View File

@ -379,7 +379,7 @@ ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
State state{ctx}; State state{ctx};
state.Process(); state.Process();
ctx.Clone(); ctx.Clone();

View File

@ -77,7 +77,9 @@ class ModuleScopeVarToEntryPointParam
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State; struct State;
}; };

View File

@ -413,10 +413,10 @@ MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
// parameters. Calls to textureLoad or textureSampleLevel that contain a // parameters. Calls to textureLoad or textureSampleLevel that contain a
// texture_external parameter will be transformed into a newly generated version // texture_external parameter will be transformed into a newly generated version
// of the function, which can perform the desired operation on a single RGBA // of the function, which can perform the desired operation on a single RGBA
// plane or on seperate Y and UV planes. // plane or on separate Y and UV planes.
void MultiplanarExternalTexture::Run(CloneContext& ctx, void MultiplanarExternalTexture::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) { DataMap&) const {
auto* new_binding_points = inputs.Get<NewBindingPoints>(); auto* new_binding_points = inputs.Get<NewBindingPoints>();
if (!new_binding_points) { if (!new_binding_points) {

View File

@ -85,7 +85,9 @@ class MultiplanarExternalTexture
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -54,7 +54,7 @@ NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
void NumWorkgroupsFromUniform::Run(CloneContext& ctx, void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) { DataMap&) const {
if (!Requires<CanonicalizeEntryPointIO>(ctx)) { if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
return; return;
} }

View File

@ -70,7 +70,9 @@ class NumWorkgroupsFromUniform
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -97,7 +97,7 @@ PadArrayElements::PadArrayElements() = default;
PadArrayElements::~PadArrayElements() = default; PadArrayElements::~PadArrayElements() = default;
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) { void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays; std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;

View File

@ -45,7 +45,9 @@ class PadArrayElements : public Castable<PadArrayElements, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -398,7 +398,7 @@ PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;
void PromoteSideEffectsToDecl::Run(CloneContext& ctx, void PromoteSideEffectsToDecl::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) { DataMap&) const {
auto* cfg = inputs.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -61,7 +61,9 @@ class PromoteSideEffectsToDecl
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
class State; class State;

View File

@ -68,7 +68,7 @@ RemovePhonies::RemovePhonies() = default;
RemovePhonies::~RemovePhonies() = default; RemovePhonies::~RemovePhonies() = default;
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) { void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks; std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;

View File

@ -41,7 +41,9 @@ class RemovePhonies : public Castable<RemovePhonies, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -39,7 +39,7 @@ RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
void RemoveUnreachableStatements::Run(CloneContext& ctx, void RemoveUnreachableStatements::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) { if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) { if (!stmt->IsReachable()) {

View File

@ -41,7 +41,9 @@ class RemoveUnreachableStatements
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -1252,7 +1252,7 @@ Renamer::Config::~Config() = default;
Renamer::Renamer() = default; Renamer::Renamer() = default;
Renamer::~Renamer() = default; Renamer::~Renamer() = default;
Output Renamer::Run(const Program* in, const DataMap& inputs) { Output Renamer::Run(const Program* in, const DataMap& inputs) const {
ProgramBuilder out; ProgramBuilder out;
// Disable auto-cloning of symbols, since we want to rename them. // Disable auto-cloning of symbols, since we want to rename them.
CloneContext ctx(&out, in, false); CloneContext ctx(&out, in, false);

View File

@ -85,7 +85,7 @@ class Renamer : public Castable<Renamer, Transform> {
/// @param program the source program to transform /// @param program the source program to transform
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformation result /// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -293,7 +293,7 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default;
Robustness::Robustness() = default; Robustness::Robustness() = default;
Robustness::~Robustness() = default; Robustness::~Robustness() = default;
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Config cfg; Config cfg;
if (auto* cfg_data = inputs.Get<Config>()) { if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data; cfg = *cfg_data;

View File

@ -74,7 +74,9 @@ class Robustness : public Castable<Robustness, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
struct State; struct State;

View File

@ -231,7 +231,7 @@ SimplifyPointers::SimplifyPointers() = default;
SimplifyPointers::~SimplifyPointers() = default; SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) { void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
if (!Requires<Unshadow>(ctx)) { if (!Requires<Unshadow>(ctx)) {
return; return;
} }

View File

@ -46,7 +46,9 @@ class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -31,7 +31,9 @@ SingleEntryPoint::SingleEntryPoint() = default;
SingleEntryPoint::~SingleEntryPoint() = default; SingleEntryPoint::~SingleEntryPoint() = default;
void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { void SingleEntryPoint::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* cfg = inputs.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -61,7 +61,9 @@ class SingleEntryPoint : public Castable<SingleEntryPoint, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -46,7 +46,8 @@ Output::Output(Program&& p) : program(std::move(p)) {}
Transform::Transform() = default; Transform::Transform() = default;
Transform::~Transform() = default; Transform::~Transform() = default;
Output Transform::Run(const Program* program, const DataMap& data /* = {} */) { Output Transform::Run(const Program* program,
const DataMap& data /* = {} */) const {
ProgramBuilder builder; ProgramBuilder builder;
CloneContext ctx(&builder, program); CloneContext ctx(&builder, program);
Output output; Output output;
@ -56,13 +57,18 @@ Output Transform::Run(const Program* program, const DataMap& data /* = {} */) {
return output; return output;
} }
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) { void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics()) TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
<< "Transform::Run() unimplemented for " << TypeInfo().name; << "Transform::Run() unimplemented for " << TypeInfo().name;
} }
bool Transform::Requires(CloneContext& ctx, bool Transform::ShouldRun(const Program*) const {
std::initializer_list<const ::tint::TypeInfo*> deps) { return true;
}
bool Transform::Requires(
CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) const {
for (auto* dep : deps) { for (auto* dep : deps) {
if (!ctx.src->HasTransformApplied(dep)) { if (!ctx.src->HasTransformApplied(dep)) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -157,7 +157,11 @@ class Transform : public Castable<Transform> {
/// @param program the source program to transform /// @param program the source program to transform
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformation result /// @returns the transformation result
virtual Output Run(const Program* program, const DataMap& data = {}); virtual Output Run(const Program* program, const DataMap& data = {}) const;
/// @param program the program to inspect
/// @returns true if this transform should be run for the given program
virtual bool ShouldRun(const Program* program) const;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
@ -166,14 +170,16 @@ class Transform : public Castable<Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs); virtual void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const;
/// Requires appends an error diagnostic to `ctx.dst` if the template type /// Requires appends an error diagnostic to `ctx.dst` if the template type
/// transforms were not already run on `ctx.src`. /// transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext /// @param ctx the CloneContext
/// @returns true if all dependency transforms have been run /// @returns true if all dependency transforms have been run
template <typename... TRANSFORMS> template <typename... TRANSFORMS>
bool Requires(CloneContext& ctx) { bool Requires(CloneContext& ctx) const {
return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...}); return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...});
} }
@ -183,7 +189,7 @@ class Transform : public Castable<Transform> {
/// @param deps the list of Transform TypeInfos /// @param deps the list of Transform TypeInfos
/// @returns true if all dependency transforms have been run /// @returns true if all dependency transforms have been run
bool Requires(CloneContext& ctx, bool Requires(CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps); std::initializer_list<const ::tint::TypeInfo*> deps) const;
/// Removes the statement `stmt` from the transformed program. /// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and /// RemoveStatement handles edge cases, like statements in the initializer and

View File

@ -24,7 +24,7 @@ namespace {
// Inherit from Transform so we have access to protected methods // Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform { struct CreateASTTypeForTest : public testing::Test, public Transform {
Output Run(const Program*, const DataMap&) override { return {}; } Output Run(const Program*, const DataMap&) const override { return {}; }
const ast::Type* create( const ast::Type* create(
std::function<sem::Type*(ProgramBuilder&)> create_sem_type) { std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {

View File

@ -91,7 +91,7 @@ Unshadow::Unshadow() = default;
Unshadow::~Unshadow() = default; Unshadow::~Unshadow() = default;
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) { void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run(); State(ctx).Run();
} }

View File

@ -39,7 +39,9 @@ class Unshadow : public Castable<Unshadow, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -34,7 +34,7 @@ VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) { DataMap&) const {
ctx.ReplaceAll( ctx.ReplaceAll(
[&](const ast::CallExpression* expr) -> const ast::CallExpression* { [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = ctx.src->Sem().Get(expr); auto* call = ctx.src->Sem().Get(expr);

View File

@ -37,7 +37,9 @@ class VectorizeScalarMatrixConstructors
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
}; };
} // namespace transform } // namespace transform

View File

@ -903,7 +903,9 @@ struct State {
VertexPulling::VertexPulling() = default; VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default; VertexPulling::~VertexPulling() = default;
void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { void VertexPulling::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto cfg = cfg_; auto cfg = cfg_;
if (auto* cfg_data = inputs.Get<Config>()) { if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data; cfg = *cfg_data;

View File

@ -172,7 +172,9 @@ class VertexPulling : public Castable<VertexPulling, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
Config cfg_; Config cfg_;

View File

@ -38,7 +38,9 @@ WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default;
void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) { void WrapArraysInStructs::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays; std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays;

View File

@ -51,7 +51,9 @@ class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
struct WrappedArrayInfo { struct WrappedArrayInfo {

View File

@ -433,7 +433,9 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) { void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
for (auto* fn : ctx.src->AST().Functions()) { for (auto* fn : ctx.src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kCompute) { if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
State{ctx}.Run(fn); State{ctx}.Run(fn);

View File

@ -39,7 +39,9 @@ class ZeroInitWorkgroupMemory
/// ProgramBuilder /// ProgramBuilder
/// @param inputs optional extra transform-specific input data /// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data /// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private: private:
struct State; struct State;