diff --git a/samples/main.cc b/samples/main.cc index e44fbfa53c..dd7970f7f8 100644 --- a/samples/main.cc +++ b/samples/main.cc @@ -773,6 +773,16 @@ int main(int argc, const char** argv) { tint::SetInternalCompilerErrorReporter(&TintInternalCompilerErrorReporter); +#if TINT_BUILD_WGSL_WRITER + tint::Program::printer = [](const tint::Program* program) { + auto result = tint::writer::wgsl::Generate(program, {}); + if (!result.error.empty()) { + return "error: " + result.error; + } + return result.wgsl; + }; +#endif // TINT_BUILD_WGSL_WRITER + if (!ParseArgs(args, &options)) { std::cerr << "Failed to parse arguments." << std::endl; return 1; diff --git a/src/program.cc b/src/program.cc index 24d2a1f04b..5e72982a47 100644 --- a/src/program.cc +++ b/src/program.cc @@ -21,6 +21,15 @@ #include "src/sem/expression.h" namespace tint { +namespace { + +std::string DefaultPrinter(const Program*) { + return ""; +} + +} // namespace + +Program::Printer Program::printer = DefaultPrinter; Program::Program() = default; diff --git a/src/program.h b/src/program.h index 32534faab5..d2ffe11123 100644 --- a/src/program.h +++ b/src/program.h @@ -186,6 +186,12 @@ class Program { /// @returns a string representation of the node std::string str(const ast::Node* node) const; + /// A function that can be used to print a program + using Printer = std::string (*)(const Program*); + + /// The Program printer used for testing and debugging. + static Printer printer; + private: Program(const Program&) = delete; diff --git a/src/test_main.cc b/src/test_main.cc index c01a825036..9b9823abbc 100644 --- a/src/test_main.cc +++ b/src/test_main.cc @@ -13,9 +13,9 @@ // limitations under the License. #include "gmock/gmock.h" +#include "src/program.h" #include "src/reader/spirv/parser_impl_test_helper.h" -#include "src/writer/hlsl/test_helper.h" -#include "src/writer/msl/test_helper.h" +#include "src/writer/wgsl/generator.h" namespace { @@ -48,6 +48,16 @@ struct Flags { int main(int argc, char** argv) { testing::InitGoogleMock(&argc, argv); +#if TINT_BUILD_WGSL_WRITER + tint::Program::printer = [](const tint::Program* program) { + auto result = tint::writer::wgsl::Generate(program, {}); + if (!result.error.empty()) { + return "error: " + result.error; + } + return result.wgsl; + }; +#endif // TINT_BUILD_WGSL_WRITER + Flags flags; if (!flags.parse(argc, argv)) { return -1; diff --git a/src/transform/manager.cc b/src/transform/manager.cc index 0c1a0f030c..c974fbc8af 100644 --- a/src/transform/manager.cc +++ b/src/transform/manager.cc @@ -14,6 +14,16 @@ #include "src/transform/manager.h" +/// If set to 1 then the transform::Manager will dump the WGSL of the program +/// before and after each transform. Helpful for debugging bad output. +#define PRINT_PROGRAM_FOR_EACH_TRANSFORM 0 + +#if PRINT_PROGRAM_FOR_EACH_TRANSFORM +#define IF_PRINT_PROGRAM(x) x +#else // PRINT_PROGRAM_FOR_EACH_TRANSFORM +#define IF_PRINT_PROGRAM(x) +#endif // PRINT_PROGRAM_FOR_EACH_TRANSFORM + TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager); namespace tint { @@ -23,16 +33,39 @@ Manager::Manager() = default; Manager::~Manager() = default; Output Manager::Run(const Program* program, const DataMap& data) { +#if PRINT_PROGRAM_FOR_EACH_TRANSFORM + auto print_program = [&](const char* msg, const Transform* transform) { + auto wgsl = Program::printer(program); + std::cout << "---------------------------------------------------------" + << std::endl; + std::cout << "-- " << msg << " " << transform->TypeInfo().name << ":" + << std::endl; + std::cout << "---------------------------------------------------------" + << std::endl; + std::cout << wgsl << std::endl; + std::cout << "---------------------------------------------------------" + << std::endl + << std::endl; + }; +#endif + Output out; if (!transforms_.empty()) { - for (auto& transform : transforms_) { + for (const auto& transform : transforms_) { + IF_PRINT_PROGRAM(print_program("Input to", transform.get())); + auto res = transform->Run(program, data); out.program = std::move(res.program); out.data.Add(std::move(res.data)); - if (!out.program.IsValid()) { + program = &out.program; + if (!program->IsValid()) { + IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get())); return out; } - program = &out.program; + + if (transform == transforms_.back()) { + IF_PRINT_PROGRAM(print_program("Output of", transform.get())); + } } } else { out.program = program->Clone();