[ir] Add base `ir::transform::Transform` class

Enable the transform manager to run a pipeline that mixes AST and IR
transforms, automatically converting the current program as necessary.

Bug: tint:1718
Change-Id: I8df76db61edd94e0b1d7c2aaabc18b394db3d8de
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132502
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-17 15:52:05 +00:00 committed by Dawn LUCI CQ
parent db5ad9f357
commit 95b06129f0
9 changed files with 409 additions and 38 deletions

View File

@ -359,10 +359,7 @@ libtint_source_set("libtint_transform_src") {
"transform/transform.cc",
"transform/transform.h",
]
deps = [
":libtint_program_src",
":libtint_utils_src",
]
deps = [ ":libtint_utils_src" ]
}
libtint_source_set("libtint_ast_transform_base_src") {
@ -370,11 +367,11 @@ libtint_source_set("libtint_ast_transform_base_src") {
"ast/transform/transform.cc",
"ast/transform/transform.h",
]
public_deps = [ ":libtint_transform_src" ]
deps = [
":libtint_builtins_src",
":libtint_program_src",
":libtint_sem_src",
":libtint_transform_src",
":libtint_type_src",
":libtint_utils_src",
]
@ -387,7 +384,10 @@ libtint_source_set("libtint_transform_manager_src") {
]
deps = [
":libtint_ast_transform_base_src",
":libtint_ir_builder_src",
":libtint_ir_src",
":libtint_program_src",
":libtint_transform_src",
]
}
@ -492,9 +492,9 @@ libtint_source_set("libtint_ast_transform_src") {
"ast/transform/zero_init_workgroup_memory.cc",
"ast/transform/zero_init_workgroup_memory.h",
]
public_deps = [ ":libtint_ast_transform_base_src" ]
deps = [
":libtint_ast_src",
":libtint_ast_transform_base_src",
":libtint_builtins_src",
":libtint_program_src",
":libtint_sem_src",
@ -1225,6 +1225,8 @@ libtint_source_set("libtint_ir_src") {
"ir/store.h",
"ir/switch.cc",
"ir/switch.h",
"ir/transform/transform.cc",
"ir/transform/transform.h",
"ir/unary.cc",
"ir/unary.h",
"ir/user_call.cc",
@ -1239,6 +1241,7 @@ libtint_source_set("libtint_ir_src") {
":libtint_builtins_src",
":libtint_constant_src",
":libtint_symbols_src",
":libtint_transform_src",
":libtint_type_src",
":libtint_utils_src",
]
@ -1247,6 +1250,7 @@ libtint_source_set("libtint_ir_src") {
source_set("libtint") {
public_deps = [
":libtint_ast_src",
":libtint_ast_transform_base_src",
":libtint_ast_transform_src",
":libtint_constant_src",
":libtint_initializer_src",
@ -1738,6 +1742,7 @@ if (tint_build_unittests) {
]
deps = [
":libtint_ast_transform_base_src",
":libtint_ast_transform_src",
":libtint_builtins_src",
":libtint_transform_manager_src",
@ -2183,12 +2188,23 @@ if (tint_build_unittests) {
"clone_context_test.cc",
"program_builder_test.cc",
"program_test.cc",
"transform/manager_test.cc",
]
deps = [
":libtint_ast_transform_base_src",
":libtint_program_src",
":libtint_transform_manager_src",
":libtint_unittests_ast_helper",
":tint_unittests_ast_src",
]
if (tint_build_ir) {
deps += [
":libtint_ir_builder_src",
":libtint_ir_src",
]
}
}
tint_unittests_source_set("tint_unittests_ir_src") {

View File

@ -768,6 +768,8 @@ if(${TINT_BUILD_IR})
ir/value.h
ir/var.cc
ir/var.h
ir/transform/transform.cc
ir/transform/transform.h
)
endif()
@ -998,6 +1000,7 @@ if(TINT_BUILD_TESTS)
symbol_test.cc
test_main.cc
ast/transform/transform_test.cc
transform/manager_test.cc
type/array_test.cc
type/atomic_test.cc
type/bool_test.cc

View File

@ -0,0 +1,24 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ir/transform/transform.h"
TINT_INSTANTIATE_TYPEINFO(tint::ir::transform::Transform);
namespace tint::ir::transform {
Transform::Transform() = default;
Transform::~Transform() = default;
} // namespace tint::ir::transform

View File

@ -0,0 +1,48 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_IR_TRANSFORM_TRANSFORM_H_
#define SRC_TINT_IR_TRANSFORM_TRANSFORM_H_
#include "src/tint/transform/transform.h"
#include <utility>
#include "src/tint/utils/castable.h"
// Forward declarations
namespace tint::ir {
class Module;
} // namespace tint::ir
namespace tint::ir::transform {
/// Interface for IR Module transforms.
class Transform : public utils::Castable<Transform, tint::transform::Transform> {
public:
/// Constructor
Transform();
/// Destructor
~Transform() override;
/// Run the transform on @p module
/// @param module the source module to transform
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
virtual void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const = 0;
};
} // namespace tint::ir::transform
#endif // SRC_TINT_IR_TRANSFORM_TRANSFORM_H_

View File

@ -17,12 +17,24 @@
#include "src/tint/ast/transform/transform.h"
#include "src/tint/program_builder.h"
#if TINT_BUILD_IR
#include "src/tint/ir/from_program.h"
#include "src/tint/ir/to_program.h"
#include "src/tint/ir/transform/transform.h"
#else
// Declare an ir::Module class so that the transform target variant compiles.
namespace ir {
class Module;
}
#endif // TINT_BUILD_IR
/// If set to 1 then the transform::Manager will dump the WGSL of the program
/// before and after each transform. Helpful for debugging bad output.
#define TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#include <iostream>
#include "src/tint/ir/disassembler.h"
#define TINT_IF_PRINT_PROGRAM(x) x
#else // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x)
@ -33,62 +45,140 @@ namespace tint::transform {
Manager::Manager() = default;
Manager::~Manager() = default;
Program Manager::Run(const Program* program,
const transform::DataMap& inputs,
transform::DataMap& outputs) const {
template <typename OUTPUT, typename INPUT>
OUTPUT Manager::RunTransforms(INPUT in,
const transform::DataMap& inputs,
transform::DataMap& outputs) const {
static_assert(std::is_same<INPUT, const Program*>() || std::is_same<INPUT, ir::Module*>());
static_assert(std::is_same<OUTPUT, Program>() || std::is_same<OUTPUT, ir::Module*>());
// The current transform target, which could be either AST or IR.
std::variant<const Program*, ir::Module*> target = in;
// A local AST program to hold the result of AST transforms.
Program ast_result;
#if TINT_BUILD_IR
// A local IR module to hold the result of AST->IR conversions.
ir::Module ir_result;
#endif
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(program);
auto print_program = [&](const char* msg, const char* name) {
std::cout << "=========================================================" << std::endl;
std::cout << "== " << msg << " " << transform->TypeInfo().name << ":" << std::endl;
std::cout << "== " << msg << " " << name << ":" << std::endl;
std::cout << "=========================================================" << std::endl;
std::cout << wgsl << std::endl;
if (!program->IsValid()) {
std::cout << "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --" << std::endl;
std::cout << program->Diagnostics().str() << std::endl;
if (std::holds_alternative<const Program*>(target)) {
auto* program = std::get<const Program*>(target);
auto wgsl = Program::printer(program);
std::cout << wgsl << std::endl;
if (!program->IsValid()) {
std::cout << "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --"
<< std::endl;
std::cout << program->Diagnostics().str() << std::endl;
}
} else if (std::holds_alternative<ir::Module*>(target)) {
#if TINT_BUILD_IR
ir::Disassembler dis(*std::get<ir::Module*>(target));
std::cout << dis.Disassemble();
#endif // TINT_BUILD_IR
}
std::cout << "=========================================================" << std::endl
<< std::endl;
};
#endif
std::optional<Program> output;
// Helper functions to get the current program state as either an AST program or IR module,
// performing a conversion if necessary.
auto get_ast = [&]() {
#if TINT_BUILD_IR
if (std::holds_alternative<ir::Module*>(target)) {
// Convert the IR module to an AST program.
ast_result = ir::ToProgram(*std::get<ir::Module*>(target));
target = &ast_result;
}
#endif // TINT_BUILD_IR
TINT_ASSERT(Transform, std::holds_alternative<const Program*>(target));
return std::get<const Program*>(target);
};
#if TINT_BUILD_IR
auto get_ir = [&]() {
if (std::holds_alternative<const Program*>(target)) {
// Convert the AST program to an IR module.
auto converted = ir::FromProgram(std::get<const Program*>(target));
TINT_ASSERT(Transform, converted);
ir_result = converted.Move();
target = &ir_result;
}
TINT_ASSERT(Transform, std::holds_alternative<ir::Module*>(target));
return std::get<ir::Module*>(target);
};
#endif // TINT_BUILD_IR
TINT_IF_PRINT_PROGRAM(print_program("Input of", this));
TINT_IF_PRINT_PROGRAM(print_program("Input of", "transform manager"));
for (const auto& transform : transforms_) {
if (auto* ast_transform = transform->As<ast::transform::Transform>()) {
if (auto result = ast_transform->Apply(program, inputs, outputs)) {
output.emplace(std::move(result.value()));
program = &output.value();
if (auto result = ast_transform->Apply(get_ast(), inputs, outputs)) {
ast_result = std::move(result.value());
target = &ast_result;
if (!program->IsValid()) {
TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
if (!ast_result.IsValid()) {
TINT_IF_PRINT_PROGRAM(
print_program("Invalid output of", transform->TypeInfo().name));
break;
}
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform->TypeInfo().name));
} else {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name
<< std::endl);
}
#if TINT_BUILD_IR
} else if (auto* ir_transform = transform->As<ir::transform::Transform>()) {
ir_transform->Run(get_ir(), inputs, outputs);
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform->TypeInfo().name));
#endif // TINT_BUILD_IR
} else {
ProgramBuilder b;
TINT_ICE(Transform, b.Diagnostics()) << "unhandled transform type";
return Program(std::move(b));
TINT_ASSERT(Transform, false && "unhandled transform type");
}
}
TINT_IF_PRINT_PROGRAM(print_program("Final output of", this));
TINT_IF_PRINT_PROGRAM(print_program("Final output of", "transform manager"));
if (!output) {
ProgramBuilder b;
CloneContext ctx{&b, program, /* auto_clone_symbols */ true};
ctx.Clone();
output = Program(std::move(b));
if constexpr (std::is_same<OUTPUT, Program>()) {
auto* result = get_ast();
if (result == in) {
// AST transform pipelines are expected to return a clone of the program, so make sure
// the input is cloned at least once even if nothing changed.
ProgramBuilder b;
CloneContext ctx{&b, result, /* auto_clone_symbols */ true};
ctx.Clone();
ast_result = Program(std::move(b));
}
return ast_result;
#if TINT_BUILD_IR
} else if constexpr (std::is_same<OUTPUT, ir::Module*>()) {
auto* result = get_ir();
if (result == &ir_result) {
// IR transform pipelines are expected to mutate the module in place, so move the local
// temporary result to the original input.
*in = std::move(ir_result);
}
return in;
#endif // TINT_BUILD_IR
}
return std::move(output.value());
}
Program Manager::Run(const Program* program,
const transform::DataMap& inputs,
transform::DataMap& outputs) const {
return RunTransforms<Program>(program, inputs, outputs);
}
#if TINT_BUILD_IR
void Manager::Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const {
auto* output = RunTransforms<ir::Module*>(module, inputs, outputs);
TINT_ASSERT(Transform, output == module);
}
#endif // TINT_BUILD_IR
} // namespace tint::transform

View File

@ -19,8 +19,16 @@
#include <utility>
#include <vector>
#include "src/tint/program.h"
#include "src/tint/transform/transform.h"
#if TINT_BUILD_IR
// Forward declarations
namespace tint::ir {
class Module;
} // namespace tint::ir
#endif // TINT_BUILD_IR
namespace tint::transform {
/// A collection of Transforms that act as a single Transform.
@ -54,8 +62,19 @@ class Manager {
/// @returns the transformed program
Program Run(const Program* program, const DataMap& inputs, DataMap& outputs) const;
#if TINT_BUILD_IR
/// Runs the transforms on @p module
/// @param module the module to transform
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(ir::Module* module, const DataMap& inputs, DataMap& outputs) const;
#endif // TINT_BUILD_IR
private:
std::vector<std::unique_ptr<Transform>> transforms_;
template <typename OUTPUT, typename INPUT>
OUTPUT RunTransforms(INPUT in, const DataMap& inputs, DataMap& outputs) const;
};
} // namespace tint::transform

View File

@ -0,0 +1,174 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/manager.h"
#include <string>
#include "gtest/gtest.h"
#include "src/tint/ast/transform/transform.h"
#include "src/tint/program_builder.h"
#if TINT_BUILD_IR
#include "src/tint/ir/builder.h" // nogncheck
#include "src/tint/ir/transform/transform.h" // nogncheck
#endif // TINT_BUILD_IR
namespace tint::transform {
namespace {
using TransformManagerTest = testing::Test;
class AST_NoOp final : public ast::transform::Transform {
ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
return SkipTransform;
}
};
class AST_AddFunction final : public ast::transform::Transform {
ApplyResult Apply(const Program* src, const DataMap&, DataMap&) const override {
ProgramBuilder b;
CloneContext ctx{&b, src};
b.Func(b.Sym("ast_func"), {}, b.ty.void_(), {});
ctx.Clone();
return Program(std::move(b));
}
};
#if TINT_BUILD_IR
class IR_AddFunction final : public ir::transform::Transform {
void Run(ir::Module* mod, const DataMap&, DataMap&) const override {
ir::Builder builder(*mod);
auto* func =
builder.CreateFunction(mod->symbols.New("ir_func"), mod->types.Get<type::Void>());
builder.Branch(func->start_target, func->end_target);
mod->functions.Push(func);
}
};
#endif // TINT_BUILD_IR
Program MakeAST() {
ProgramBuilder b;
b.Func(b.Sym("main"), {}, b.ty.void_(), {});
return Program(std::move(b));
}
#if TINT_BUILD_IR
ir::Module MakeIR() {
ir::Module mod;
ir::Builder builder(mod);
auto* func =
builder.CreateFunction(builder.ir.symbols.New("main"), builder.ir.types.Get<type::Void>());
builder.Branch(func->start_target, func->end_target);
builder.ir.functions.Push(func);
return mod;
}
#endif // TINT_BUILD_IR
// Test that an AST program is always cloned, even if all transforms are skipped.
TEST_F(TransformManagerTest, AST_AlwaysClone) {
Program ast = MakeAST();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<AST_NoOp>();
auto result = manager.Run(&ast, {}, outputs);
EXPECT_TRUE(result.IsValid()) << result.Diagnostics();
EXPECT_NE(result.ID(), ast.ID());
ASSERT_EQ(result.AST().Functions().Length(), 1u);
EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "main");
}
#if TINT_BUILD_IR
// Test that an IR module is mutated in place.
TEST_F(TransformManagerTest, IR_MutateInPlace) {
ir::Module ir = MakeIR();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<IR_AddFunction>();
manager.Run(&ir, {}, outputs);
ASSERT_EQ(ir.functions.Length(), 2u);
EXPECT_EQ(ir.functions[0]->name.Name(), "main");
EXPECT_EQ(ir.functions[1]->name.Name(), "ir_func");
}
TEST_F(TransformManagerTest, AST_MixedTransforms_AST_Before_IR) {
Program ast = MakeAST();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<AST_AddFunction>();
manager.Add<IR_AddFunction>();
auto result = manager.Run(&ast, {}, outputs);
ASSERT_TRUE(result.IsValid()) << result.Diagnostics();
ASSERT_EQ(result.AST().Functions().Length(), 3u);
EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "ast_func");
EXPECT_EQ(result.AST().Functions()[1]->name->symbol.Name(), "main");
EXPECT_EQ(result.AST().Functions()[2]->name->symbol.Name(), "ir_func");
}
TEST_F(TransformManagerTest, AST_MixedTransforms_IR_Before_AST) {
Program ast = MakeAST();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<IR_AddFunction>();
manager.Add<AST_AddFunction>();
auto result = manager.Run(&ast, {}, outputs);
ASSERT_TRUE(result.IsValid()) << result.Diagnostics();
ASSERT_EQ(result.AST().Functions().Length(), 3u);
EXPECT_EQ(result.AST().Functions()[0]->name->symbol.Name(), "ast_func");
EXPECT_EQ(result.AST().Functions()[1]->name->symbol.Name(), "main");
EXPECT_EQ(result.AST().Functions()[2]->name->symbol.Name(), "ir_func");
}
TEST_F(TransformManagerTest, IR_MixedTransforms_AST_Before_IR) {
ir::Module ir = MakeIR();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<AST_AddFunction>();
manager.Add<IR_AddFunction>();
manager.Run(&ir, {}, outputs);
ASSERT_EQ(ir.functions.Length(), 3u);
EXPECT_EQ(ir.functions[0]->name.Name(), "ast_func");
EXPECT_EQ(ir.functions[1]->name.Name(), "main");
EXPECT_EQ(ir.functions[2]->name.Name(), "ir_func");
}
TEST_F(TransformManagerTest, IR_MixedTransforms_IR_Before_AST) {
ir::Module ir = MakeIR();
transform::Manager manager;
transform::DataMap outputs;
manager.Add<IR_AddFunction>();
manager.Add<AST_AddFunction>();
manager.Run(&ir, {}, outputs);
ASSERT_EQ(ir.functions.Length(), 3u);
EXPECT_EQ(ir.functions[0]->name.Name(), "ast_func");
EXPECT_EQ(ir.functions[1]->name.Name(), "main");
EXPECT_EQ(ir.functions[2]->name.Name(), "ir_func");
}
#endif // TINT_BUILD_IR
} // namespace
} // namespace tint::transform

View File

@ -14,8 +14,6 @@
#include "src/tint/transform/transform.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);

View File

@ -19,7 +19,6 @@
#include <unordered_map>
#include <utility>
#include "src/tint/program.h"
#include "src/tint/utils/castable.h"
namespace tint::transform {