diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 93f3133ebb..59a36c0690 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1527,20 +1527,36 @@ uint32_t Builder::GenerateCastExpression(ast::CastExpression* cast) { auto* from_type = cast->expr()->result_type()->UnwrapPtrIfNeeded(); spv::Op op = spv::Op::OpNop; - if (from_type->IsI32() && to_type->IsF32()) { + if ((from_type->IsI32() && to_type->IsF32()) || + (from_type->is_signed_integer_vector() && to_type->is_float_vector())) { op = spv::Op::OpConvertSToF; - } else if (from_type->IsU32() && to_type->IsF32()) { + } else if ((from_type->IsU32() && to_type->IsF32()) || + (from_type->is_unsigned_integer_vector() && + to_type->is_float_vector())) { op = spv::Op::OpConvertUToF; - } else if (from_type->IsF32() && to_type->IsI32()) { + } else if ((from_type->IsF32() && to_type->IsI32()) || + (from_type->is_float_vector() && + to_type->is_signed_integer_vector())) { op = spv::Op::OpConvertFToS; - } else if (from_type->IsF32() && to_type->IsU32()) { + } else if ((from_type->IsF32() && to_type->IsU32()) || + (from_type->is_float_vector() && + to_type->is_unsigned_integer_vector())) { op = spv::Op::OpConvertFToU; } else if ((from_type->IsU32() && to_type->IsU32()) || (from_type->IsI32() && to_type->IsI32()) || - (from_type->IsF32() && to_type->IsF32())) { + (from_type->IsF32() && to_type->IsF32()) || + (from_type->is_unsigned_integer_vector() && + to_type->is_unsigned_integer_vector()) || + (from_type->is_signed_integer_vector() && + to_type->is_signed_integer_vector()) || + (from_type->is_float_vector() && to_type->is_float_vector())) { op = spv::Op::OpCopyObject; } else if ((from_type->IsI32() && to_type->IsU32()) || - (from_type->IsU32() && to_type->IsI32())) { + (from_type->IsU32() && to_type->IsI32()) || + (from_type->is_signed_integer_vector() && + to_type->is_unsigned_integer_vector()) || + (from_type->is_unsigned_integer_vector() && + to_type->is_integer_scalar_or_vector())) { op = spv::Op::OpBitcast; } diff --git a/src/writer/spirv/builder_cast_expression_test.cc b/src/writer/spirv/builder_cast_expression_test.cc index c559ec9b4c..675bb203fd 100644 --- a/src/writer/spirv/builder_cast_expression_test.cc +++ b/src/writer/spirv/builder_cast_expression_test.cc @@ -22,6 +22,7 @@ #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/u32_type.h" +#include "src/ast/type/vector_type.h" #include "src/ast/uint_literal.h" #include "src/context.h" #include "src/type_determiner.h" @@ -329,6 +330,224 @@ TEST_F(BuilderTest, Cast_F32ToF32) { )"); } +TEST_F(BuilderTest, Cast_Vectors_I32_to_F32) { + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); + + ast::CastExpression cast(&fvec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeFloat 32 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertSToF %7 %9 +)"); +} + +TEST_F(BuilderTest, Cast_Vectors_U32_to_F32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); + + ast::CastExpression cast(&fvec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeFloat 32 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertUToF %7 %9 +)"); +} + +TEST_F(BuilderTest, Cast_Vectors_F32_to_I32) { + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); + + ast::CastExpression cast(&ivec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 1 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertFToS %7 %9 +)"); +} + +TEST_F(BuilderTest, Cast_Vectors_F32_to_U32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::F32Type f32; + ast::type::VectorType fvec3(&f32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &fvec3); + + ast::CastExpression cast(&uvec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 0 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpConvertFToU %7 %9 +)"); +} + +TEST_F(BuilderTest, Cast_Vectors_U32_to_U32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &uvec3); + + ast::CastExpression cast(&uvec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%7 = OpLoad %3 %1 +%6 = OpCopyObject %3 %7 +)"); +} + +TEST_F(BuilderTest, Cast_Vectors_I32_to_U32) { + ast::type::U32Type u32; + ast::type::VectorType uvec3(&u32, 3); + ast::type::I32Type i32; + ast::type::VectorType ivec3(&i32, 3); + + auto var = + std::make_unique("i", ast::StorageClass::kPrivate, &ivec3); + + ast::CastExpression cast(&uvec3, + std::make_unique("i")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + EXPECT_EQ(b.GenerateCastExpression(&cast), 6u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeInt 32 0 +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%6 = OpBitcast %7 %9 +)"); +} + } // namespace } // namespace spirv } // namespace writer