diff --git a/src/writer/spirv/test_helper.h b/src/writer/spirv/test_helper.h index 241a50f817..94dc745ef9 100644 --- a/src/writer/spirv/test_helper.h +++ b/src/writer/spirv/test_helper.h @@ -19,6 +19,7 @@ #include #include "gtest/gtest.h" +#include "src/ast/builder.h" #include "src/ast/module.h" #include "src/context.h" #include "src/type_determiner.h" @@ -30,28 +31,24 @@ namespace spirv { /// Helper class for testing template -class TestHelperBase : public BASE { +class TestHelperBase : public ast::BuilderWithContext, public BASE { public: - TestHelperBase() : td(&ctx, &mod), b(&ctx, &mod) {} + TestHelperBase() : td(ctx, &mod), b(ctx, &mod) {} ~TestHelperBase() = default; - /// Creates a new `ast::Node` owned by the Context. When the Context is - /// destructed, the `ast::Node` will also be destructed. - /// @param args the arguments to pass to the type constructor - /// @returns the node pointer - template - T* create(ARGS&&... args) { - return ctx.create(std::forward(args)...); - } - - /// The context - Context ctx; /// The module ast::Module mod; /// The type determiner TypeDeterminer td; /// The generator - Builder b; + spirv::Builder b; + + protected: + /// Called whenever a new variable is built with `Var()`. + /// @param var the variable that was built + void OnVariableBuilt(ast::Variable* var) override { + td.RegisterVariableForTesting(var); + } }; using TestHelper = TestHelperBase;