diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index d59d21f106..c999d85eeb 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -68,17 +68,25 @@ class FakeExpr : public ast::Expression { void to_str(std::ostream&, size_t) const override {} }; -class TypeDeterminerTest : public testing::Test { +class TypeDeterminerHelper { public: - void SetUp() override { td_ = std::make_unique(&ctx_); } + TypeDeterminerHelper() : td_(std::make_unique(&ctx_)) {} TypeDeterminer* td() const { return td_.get(); } + ast::Module* mod() { return &mod_; } private: Context ctx_; + ast::Module mod_; std::unique_ptr td_; }; +class TypeDeterminerTest : public TypeDeterminerHelper, public testing::Test {}; + +template +class TypeDeterminerTestWithParam : public TypeDeterminerHelper, + public testing::TestWithParam {}; + TEST_F(TypeDeterminerTest, Error_WithEmptySource) { FakeStmt s; s.set_source(Source{0, 0}); @@ -422,14 +430,12 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Array) { auto idx = std::make_unique( std::make_unique(&i32, 2)); - - ast::Module m; auto var = std::make_unique("my_var", ast::StorageClass::kNone, &ary); - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ArrayAccessorExpression acc( std::make_unique("my_var"), std::move(idx)); @@ -445,14 +451,12 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix) { auto idx = std::make_unique( std::make_unique(&i32, 2)); - - ast::Module m; auto var = std::make_unique("my_var", ast::StorageClass::kNone, &mat); - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ArrayAccessorExpression acc( std::make_unique("my_var"), std::move(idx)); @@ -471,14 +475,12 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Matrix_BothDimensions) { std::make_unique(&i32, 2)); auto idx2 = std::make_unique( std::make_unique(&i32, 1)); - - ast::Module m; auto var = std::make_unique("my_var", ast::StorageClass::kNone, &mat); - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ArrayAccessorExpression acc( std::make_unique( @@ -498,14 +500,12 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) { auto idx = std::make_unique( std::make_unique(&i32, 2)); - - ast::Module m; auto var = std::make_unique("my_var", ast::StorageClass::kNone, &vec); - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ArrayAccessorExpression acc( std::make_unique("my_var"), std::move(idx)); @@ -530,11 +530,10 @@ TEST_F(TypeDeterminerTest, Expr_Call) { ast::VariableList params; auto func = std::make_unique("my_func", std::move(params), &f32); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); // Register the function - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ExpressionList call_params; ast::CallExpression call( @@ -551,11 +550,10 @@ TEST_F(TypeDeterminerTest, Expr_Call_WithParams) { ast::VariableList params; auto func = std::make_unique("my_func", std::move(params), &f32); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); // Register the function - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::ExpressionList call_params; call_params.push_back(std::make_unique( @@ -614,14 +612,12 @@ TEST_F(TypeDeterminerTest, Expr_Constructor_Type) { TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) { ast::type::F32Type f32; - - ast::Module m; auto var = std::make_unique("my_var", ast::StorageClass::kNone, &f32); - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::IdentifierExpression ident("my_var"); EXPECT_TRUE(td()->DetermineResultType(&ident)); @@ -659,11 +655,10 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { ast::VariableList params; auto func = std::make_unique("my_func", std::move(params), &f32); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); // Register the function - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); ast::IdentifierExpression ident("my_func"); EXPECT_TRUE(td()->DetermineResultType(&ident)); @@ -690,11 +685,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) { auto var = std::make_unique("my_struct", ast::StorageClass::kNone, &st); - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); auto ident = std::make_unique("my_struct"); auto mem_ident = std::make_unique("second_member"); @@ -711,11 +705,10 @@ TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { auto var = std::make_unique("my_vec", ast::StorageClass::kNone, &vec3); - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); auto ident = std::make_unique("my_vec"); auto swizzle = std::make_unique("xy"); @@ -780,12 +773,10 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) { auto var = std::make_unique("c", ast::StorageClass::kNone, &stA); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td()->Determine(&m)); + EXPECT_TRUE(td()->Determine(mod())); auto ident = std::make_unique("c"); auto mem_ident = std::make_unique("mem"); @@ -809,7 +800,7 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) { EXPECT_EQ(mem.result_type()->AsVector()->size(), 2u); } -using Expr_Binary_BitwiseTest = testing::TestWithParam; +using Expr_Binary_BitwiseTest = TypeDeterminerTestWithParam; TEST_P(Expr_Binary_BitwiseTest, Scalar) { auto op = GetParam(); @@ -817,21 +808,16 @@ TEST_P(Expr_Binary_BitwiseTest, Scalar) { auto var = std::make_unique("val", ast::StorageClass::kNone, &i32); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine(mod())) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); EXPECT_TRUE(expr.result_type()->IsI32()); } @@ -844,21 +830,16 @@ TEST_P(Expr_Binary_BitwiseTest, Vector) { auto var = std::make_unique("val", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine(mod())) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsI32()); @@ -877,7 +858,7 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, ast::BinaryOp::kDivide, ast::BinaryOp::kModulo)); -using Expr_Binary_LogicalTest = testing::TestWithParam; +using Expr_Binary_LogicalTest = TypeDeterminerTestWithParam; TEST_P(Expr_Binary_LogicalTest, Scalar) { auto op = GetParam(); @@ -885,21 +866,16 @@ TEST_P(Expr_Binary_LogicalTest, Scalar) { auto var = std::make_unique("val", ast::StorageClass::kNone, &bool_type); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); EXPECT_TRUE(expr.result_type()->IsBool()); } @@ -912,21 +888,16 @@ TEST_P(Expr_Binary_LogicalTest, Vector) { auto var = std::make_unique("val", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine(mod())) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool()); @@ -937,7 +908,7 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, testing::Values(ast::BinaryOp::kLogicalAnd, ast::BinaryOp::kLogicalOr)); -using Expr_Binary_CompareTest = testing::TestWithParam; +using Expr_Binary_CompareTest = TypeDeterminerTestWithParam; TEST_P(Expr_Binary_CompareTest, Scalar) { auto op = GetParam(); @@ -945,21 +916,16 @@ TEST_P(Expr_Binary_CompareTest, Scalar) { auto var = std::make_unique("val", ast::StorageClass::kNone, &i32); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); EXPECT_TRUE(expr.result_type()->IsBool()); } @@ -972,21 +938,16 @@ TEST_P(Expr_Binary_CompareTest, Vector) { auto var = std::make_unique("val", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( op, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool()); @@ -1006,22 +967,17 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Scalar_Scalar) { auto var = std::make_unique("val", ast::StorageClass::kNone, &i32); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("val"), std::make_unique("val")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); EXPECT_TRUE(expr.result_type()->IsI32()); } @@ -1034,23 +990,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Scalar) { std::make_unique("scalar", ast::StorageClass::kNone, &f32); auto vector = std::make_unique( "vector", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(scalar)); - m.AddGlobalVariable(std::move(vector)); + mod()->AddGlobalVariable(std::move(scalar)); + mod()->AddGlobalVariable(std::move(vector)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("vector"), std::make_unique("scalar")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); @@ -1065,23 +1016,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Scalar_Vector) { std::make_unique("scalar", ast::StorageClass::kNone, &f32); auto vector = std::make_unique( "vector", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(scalar)); - m.AddGlobalVariable(std::move(vector)); + mod()->AddGlobalVariable(std::move(scalar)); + mod()->AddGlobalVariable(std::move(vector)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("scalar"), std::make_unique("vector")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); @@ -1094,22 +1040,17 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Vector) { auto vector = std::make_unique( "vector", ast::StorageClass::kNone, &vec3); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(vector)); + mod()->AddGlobalVariable(std::move(vector)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("vector"), std::make_unique("vector")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); @@ -1124,23 +1065,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Scalar) { std::make_unique("scalar", ast::StorageClass::kNone, &f32); auto matrix = std::make_unique( "matrix", ast::StorageClass::kNone, &mat3x2); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(scalar)); - m.AddGlobalVariable(std::move(matrix)); + mod()->AddGlobalVariable(std::move(scalar)); + mod()->AddGlobalVariable(std::move(matrix)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("matrix"), std::make_unique("scalar")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsMatrix()); @@ -1158,23 +1094,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Scalar_Matrix) { std::make_unique("scalar", ast::StorageClass::kNone, &f32); auto matrix = std::make_unique( "matrix", ast::StorageClass::kNone, &mat3x2); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(scalar)); - m.AddGlobalVariable(std::move(matrix)); + mod()->AddGlobalVariable(std::move(scalar)); + mod()->AddGlobalVariable(std::move(matrix)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("scalar"), std::make_unique("matrix")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsMatrix()); @@ -1193,23 +1124,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Vector) { "vector", ast::StorageClass::kNone, &vec3); auto matrix = std::make_unique( "matrix", ast::StorageClass::kNone, &mat3x2); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(vector)); - m.AddGlobalVariable(std::move(matrix)); + mod()->AddGlobalVariable(std::move(vector)); + mod()->AddGlobalVariable(std::move(matrix)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("matrix"), std::make_unique("vector")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); @@ -1225,23 +1151,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Vector_Matrix) { "vector", ast::StorageClass::kNone, &vec3); auto matrix = std::make_unique( "matrix", ast::StorageClass::kNone, &mat3x2); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(vector)); - m.AddGlobalVariable(std::move(matrix)); + mod()->AddGlobalVariable(std::move(vector)); + mod()->AddGlobalVariable(std::move(matrix)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("vector"), std::make_unique("matrix")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsVector()); EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); @@ -1257,23 +1178,18 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Matrix) { "mat4x3", ast::StorageClass::kNone, &mat4x3); auto matrix2 = std::make_unique( "mat3x4", ast::StorageClass::kNone, &mat3x4); - - Context ctx; - TypeDeterminer td(&ctx); - - ast::Module m; - m.AddGlobalVariable(std::move(matrix1)); - m.AddGlobalVariable(std::move(matrix2)); + mod()->AddGlobalVariable(std::move(matrix1)); + mod()->AddGlobalVariable(std::move(matrix2)); // Register the global - ASSERT_TRUE(td.Determine(&m)) << td.error(); + ASSERT_TRUE(td()->Determine((mod()))) << td()->error(); ast::BinaryExpression expr( ast::BinaryOp::kMultiply, std::make_unique("mat4x3"), std::make_unique("mat3x4")); - ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + ASSERT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); ASSERT_NE(expr.result_type(), nullptr); ASSERT_TRUE(expr.result_type()->IsMatrix()); @@ -1284,7 +1200,7 @@ TEST_F(TypeDeterminerTest, Expr_Binary_Multiply_Matrix_Matrix) { } using UnaryDerivativeExpressionTest = - testing::TestWithParam; + TypeDeterminerTestWithParam; TEST_P(UnaryDerivativeExpressionTest, Expr_UnaryDerivative) { auto derivative = GetParam(); @@ -1294,20 +1210,15 @@ TEST_P(UnaryDerivativeExpressionTest, Expr_UnaryDerivative) { auto var = std::make_unique("ident", ast::StorageClass::kNone, &vec4); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); - - Context ctx; - TypeDeterminer td(&ctx); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); ast::UnaryDerivativeExpression der( derivative, ast::DerivativeModifier::kNone, std::make_unique("ident")); - EXPECT_TRUE(td.DetermineResultType(&der)); + EXPECT_TRUE(td()->DetermineResultType(&der)); ASSERT_NE(der.result_type(), nullptr); ASSERT_TRUE(der.result_type()->IsVector()); EXPECT_TRUE(der.result_type()->AsVector()->type()->IsF32()); @@ -1319,7 +1230,8 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, ast::UnaryDerivative::kDpdy, ast::UnaryDerivative::kFwidth)); -using UnaryMethodExpressionBoolTest = testing::TestWithParam; +using UnaryMethodExpressionBoolTest = + TypeDeterminerTestWithParam; TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) { auto op = GetParam(); @@ -1328,22 +1240,17 @@ TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) { auto var = std::make_unique("my_var", ast::StorageClass::kNone, &vec3); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); ast::ExpressionList params; params.push_back(std::make_unique("my_var")); ast::UnaryMethodExpression exp(op, std::move(params)); - Context ctx; - TypeDeterminer td(&ctx); - // Register the variable - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); - EXPECT_TRUE(td.DetermineResultType(&exp)); + EXPECT_TRUE(td()->DetermineResultType(&exp)); ASSERT_NE(exp.result_type(), nullptr); EXPECT_TRUE(exp.result_type()->IsBool()); } @@ -1352,7 +1259,8 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, testing::Values(ast::UnaryMethod::kAny, ast::UnaryMethod::kAll)); -using UnaryMethodExpressionVecTest = testing::TestWithParam; +using UnaryMethodExpressionVecTest = + TypeDeterminerTestWithParam; TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) { auto op = GetParam(); @@ -1361,22 +1269,17 @@ TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) { auto var = std::make_unique("my_var", ast::StorageClass::kNone, &vec3); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); ast::ExpressionList params; params.push_back(std::make_unique("my_var")); ast::UnaryMethodExpression exp(op, std::move(params)); - Context ctx; - TypeDeterminer td(&ctx); - // Register the variable - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); - EXPECT_TRUE(td.DetermineResultType(&exp)); + EXPECT_TRUE(td()->DetermineResultType(&exp)); ASSERT_NE(exp.result_type(), nullptr); ASSERT_TRUE(exp.result_type()->IsVector()); EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool()); @@ -1389,22 +1292,17 @@ TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) { auto var = std::make_unique("my_var", ast::StorageClass::kNone, &f32); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); ast::ExpressionList params; params.push_back(std::make_unique("my_var")); ast::UnaryMethodExpression exp(op, std::move(params)); - Context ctx; - TypeDeterminer td(&ctx); - // Register the variable - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); - EXPECT_TRUE(td.DetermineResultType(&exp)); + EXPECT_TRUE(td()->DetermineResultType(&exp)); ASSERT_NE(exp.result_type(), nullptr); EXPECT_TRUE(exp.result_type()->IsBool()); } @@ -1421,9 +1319,7 @@ TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) { auto var = std::make_unique("my_var", ast::StorageClass::kNone, &vec3); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(var)); ast::ExpressionList params; params.push_back(std::make_unique("my_var")); @@ -1431,13 +1327,10 @@ TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) { ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params)); - Context ctx; - TypeDeterminer td(&ctx); - // Register the variable - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); - EXPECT_TRUE(td.DetermineResultType(&exp)); + EXPECT_TRUE(td()->DetermineResultType(&exp)); ASSERT_NE(exp.result_type(), nullptr); EXPECT_TRUE(exp.result_type()->IsF32()); } @@ -1451,10 +1344,8 @@ TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) { std::make_unique("v3", ast::StorageClass::kNone, &vec3); auto var2 = std::make_unique("v2", ast::StorageClass::kNone, &vec2); - - ast::Module m; - m.AddGlobalVariable(std::move(var1)); - m.AddGlobalVariable(std::move(var2)); + mod()->AddGlobalVariable(std::move(var1)); + mod()->AddGlobalVariable(std::move(var2)); ast::ExpressionList params; params.push_back(std::make_unique("v3")); @@ -1463,13 +1354,10 @@ TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) { ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct, std::move(params)); - Context ctx; - TypeDeterminer td(&ctx); - // Register the variable - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); - EXPECT_TRUE(td.DetermineResultType(&exp)); + EXPECT_TRUE(td()->DetermineResultType(&exp)); ASSERT_NE(exp.result_type(), nullptr); ASSERT_TRUE(exp.result_type()->IsMatrix()); auto* mat = exp.result_type()->AsMatrix(); @@ -1478,7 +1366,7 @@ TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) { EXPECT_EQ(mat->columns(), 2u); } -using UnaryOpExpressionTest = testing::TestWithParam; +using UnaryOpExpressionTest = TypeDeterminerTestWithParam; TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) { auto op = GetParam(); @@ -1488,19 +1376,14 @@ TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) { auto var = std::make_unique("ident", ast::StorageClass::kNone, &vec4); - - ast::Module m; - m.AddGlobalVariable(std::move(var)); - - Context ctx; - TypeDeterminer td(&ctx); + mod()->AddGlobalVariable(std::move(var)); // Register the global - EXPECT_TRUE(td.Determine(&m)); + EXPECT_TRUE(td()->Determine((mod()))); ast::UnaryOpExpression der( op, std::make_unique("ident")); - EXPECT_TRUE(td.DetermineResultType(&der)); + EXPECT_TRUE(td()->DetermineResultType(&der)); ASSERT_NE(der.result_type(), nullptr); ASSERT_TRUE(der.result_type()->IsVector()); EXPECT_TRUE(der.result_type()->AsVector()->type()->IsF32()); @@ -1525,10 +1408,9 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) { stmts.push_back(std::move(stmt)); func->set_body(std::move(stmts)); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); - EXPECT_TRUE(td()->Determine(&m)) << td()->error(); + EXPECT_TRUE(td()->Determine((mod()))) << td()->error(); EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kFunction); } @@ -1547,10 +1429,9 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) { stmts.push_back(std::move(stmt)); func->set_body(std::move(stmts)); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); - EXPECT_TRUE(td()->Determine(&m)) << td()->error(); + EXPECT_TRUE(td()->Determine((mod()))) << td()->error(); EXPECT_EQ(var_ptr->storage_class(), ast::StorageClass::kNone); } @@ -1567,10 +1448,9 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) { stmts.push_back(std::move(stmt)); func->set_body(std::move(stmts)); - ast::Module m; - m.AddFunction(std::move(func)); + mod()->AddFunction(std::move(func)); - EXPECT_FALSE(td()->Determine(&m)); + EXPECT_FALSE(td()->Determine((mod()))); EXPECT_EQ(td()->error(), "function variable has a non-function storage class"); }