diff --git a/src/writer/spirv/builder_assign_test.cc b/src/writer/spirv/builder_assign_test.cc index 99c3af730b..7847d065b6 100644 --- a/src/writer/spirv/builder_assign_test.cc +++ b/src/writer/spirv/builder_assign_test.cc @@ -15,12 +15,20 @@ #include #include "gtest/gtest.h" +#include "src/ast/array_accessor_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" +#include "src/ast/int_literal.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/struct.h" +#include "src/ast/struct_member.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" +#include "src/ast/type_constructor_expression.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" @@ -35,7 +43,6 @@ using BuilderTest = testing::Test; TEST_F(BuilderTest, Assign_Var) { ast::type::F32Type f32; - ast::type::VectorType vec(&f32, 3); ast::Variable v("var", ast::StorageClass::kOutput, &f32); @@ -70,9 +77,214 @@ TEST_F(BuilderTest, Assign_Var) { )"); } -TEST_F(BuilderTest, DISABLED_Assign_StructMember) {} +TEST_F(BuilderTest, Assign_StructMember) { + ast::type::F32Type f32; -TEST_F(BuilderTest, DISABLED_Assign_Vector) {} + // my_struct { + // a : f32 + // b : f32 + // } + // var ident : my_struct + // ident.b = 4.0; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &f32, std::move(decos))); + members.push_back( + std::make_unique("b", &f32, std::move(decos))); + + auto s = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + ast::type::StructType s_type(std::move(s)); + s_type.set_name("my_struct"); + + ast::Variable v("ident", ast::StorageClass::kFunction, &s_type); + + auto ident = std::make_unique( + std::make_unique("ident"), + std::make_unique("b")); + + auto val = std::make_unique( + std::make_unique(&f32, 4.0f)); + + ast::AssignmentStatement assign(std::move(ident), std::move(val)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); + EXPECT_FALSE(b.has_error()); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeStruct %4 %4 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +%5 = OpTypeInt 32 0 +%6 = OpConstant %5 1 +%7 = OpTypePointer Function %4 +%9 = OpConstant %4 4 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%8 = OpAccessChain %7 %1 %6 +OpStore %8 %9 +)"); +} + +TEST_F(BuilderTest, Assign_Vector) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + ast::Variable v("var", ast::StorageClass::kOutput, &vec3); + + auto ident = std::make_unique("var"); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 3.0f))); + + auto val = + std::make_unique(&vec3, std::move(vals)); + + ast::AssignmentStatement assign(std::move(ident), std::move(val)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); + EXPECT_FALSE(b.has_error()); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Output %3 +%1 = OpVariable %2 Output +%5 = OpConstant %4 1 +%6 = OpConstant %4 3 +%7 = OpConstantComposite %3 %5 %5 %6 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpStore %1 %7 +)"); +} + +TEST_F(BuilderTest, Assign_Vector_MemberByName) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + // var.y = 1 + + ast::Variable v("var", ast::StorageClass::kOutput, &vec3); + + auto ident = std::make_unique( + std::make_unique("var"), + std::make_unique("y")); + auto val = std::make_unique( + std::make_unique(&f32, 1.0f)); + + ast::AssignmentStatement assign(std::move(ident), std::move(val)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); + EXPECT_FALSE(b.has_error()); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Output %3 +%1 = OpVariable %2 Output +%5 = OpTypeInt 32 0 +%6 = OpConstant %5 1 +%7 = OpTypePointer Output %4 +%9 = OpConstant %4 1 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%8 = OpAccessChain %7 %1 %6 +OpStore %8 %9 +)"); +} + +TEST_F(BuilderTest, Assign_Vector_MemberByIndex) { + ast::type::I32Type i32; + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + // var[1] = 1 + + ast::Variable v("var", ast::StorageClass::kOutput, &vec3); + + auto ident = std::make_unique( + std::make_unique("var"), + std::make_unique( + std::make_unique(&i32, 1))); + auto val = std::make_unique( + std::make_unique(&f32, 1.0f)); + + ast::AssignmentStatement assign(std::move(ident), std::move(val)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&v); + + ASSERT_TRUE(td.DetermineResultType(&assign)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_TRUE(b.GenerateAssignStatement(&assign)) << b.error(); + EXPECT_FALSE(b.has_error()); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Output %3 +%1 = OpVariable %2 Output +%5 = OpTypeInt 32 1 +%6 = OpConstant %5 1 +%7 = OpTypePointer Output %4 +%9 = OpConstant %4 1 +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%8 = OpAccessChain %7 %1 %6 +OpStore %8 %9 +)"); +} } // namespace } // namespace spirv