diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc index cae468f93a..0e29f5f2b2 100644 --- a/src/reader/spirv/function_cfg_test.cc +++ b/src/reader/spirv/function_cfg_test.cc @@ -4639,9 +4639,10 @@ TEST_F(SpvParserTest, ClassifyCFGEdge_IfBreak_BypassesMerge_IsError) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 20 to block 99 is an invalid exit from " - "construct starting at block 10; branch bypasses merge block 50")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 20 to block 99 is an invalid exit from " + "construct starting at block 10; branch bypasses merge block 50")); } TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_FromSwitchCaseDirect) { @@ -4832,9 +4833,10 @@ TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_BypassesMerge_IsError) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 20 to block 99 is an invalid exit from " - "construct starting at block 10; branch bypasses merge block 50")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 20 to block 99 is an invalid exit from " + "construct starting at block 10; branch bypasses merge block 50")); } TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_FromNestedLoop_IsError) { @@ -4866,9 +4868,10 @@ TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_FromNestedLoop_IsError) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 30 to block 99 is an invalid exit from " - "construct starting at block 20; branch bypasses merge block 80")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 30 to block 99 is an invalid exit from " + "construct starting at block 20; branch bypasses merge block 80")); } TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_FromNestedSwitch_IsError) { @@ -4897,9 +4900,10 @@ TEST_F(SpvParserTest, ClassifyCFGEdges_SwitchBreak_FromNestedSwitch_IsError) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 30 to block 99 is an invalid exit from " - "construct starting at block 20; branch bypasses merge block 80")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 30 to block 99 is an invalid exit from " + "construct starting at block 20; branch bypasses merge block 80")); } TEST_F(SpvParserTest, ClassifyCFGEdges_LoopBreak_FromLoopBody) { @@ -5145,9 +5149,10 @@ TEST_F(SpvParserTest, ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 30 to block 99 is an invalid exit from " - "construct starting at block 20; branch bypasses merge block 50")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 30 to block 99 is an invalid exit from " + "construct starting at block 20; branch bypasses merge block 50")); } TEST_F(SpvParserTest, @@ -5181,9 +5186,10 @@ TEST_F(SpvParserTest, ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 45 to block 99 is an invalid exit from " - "construct starting at block 40; branch bypasses merge block 50")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 45 to block 99 is an invalid exit from " + "construct starting at block 40; branch bypasses merge block 50")); } TEST_F(SpvParserTest, ClassifyCFGEdges_LoopContinue_LoopBodyToContinue) { @@ -5935,9 +5941,10 @@ TEST_F(SpvParserTest, ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 30 to block 60 is an invalid exit from " - "construct starting at block 20; branch bypasses continue target 50")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 30 to block 60 is an invalid exit from " + "construct starting at block 20; branch bypasses continue target 50")); } TEST_F(SpvParserTest, @@ -5968,9 +5975,10 @@ TEST_F(SpvParserTest, ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(FlowClassifyCFGEdges(&fe)); - EXPECT_THAT(p->error(), - Eq("Branch from block 50 to block 60 is an invalid exit from " - "construct starting at block 50; branch bypasses merge block 80")); + EXPECT_THAT( + p->error(), + Eq("Branch from block 50 to block 60 is an invalid exit from " + "construct starting at block 50; branch bypasses merge block 80")); } TEST_F(SpvParserTest, ClassifyCFGEdges_TooManyBackedges) { diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index e8f0534e55..c5d4db612d 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1231,7 +1231,15 @@ uint32_t Builder::GenerateCastExpression(ast::CastExpression* cast) { op = spv::Op::OpConvertFToS; } else if (from_type->IsF32() && to_type->IsU32()) { op = spv::Op::OpConvertFToU; + } else if ((from_type->IsU32() && to_type->IsU32()) || + (from_type->IsI32() && to_type->IsI32()) || + (from_type->IsF32() && to_type->IsF32())) { + op = spv::Op::OpCopyObject; + } else if ((from_type->IsI32() && to_type->IsU32()) || + (from_type->IsU32() && to_type->IsI32())) { + op = spv::Op::OpBitcast; } + if (op == spv::Op::OpNop) { error_ = "unable to determine conversion type for cast, from: " + from_type->type_name() + " to: " + to_type->type_name(); diff --git a/src/writer/spirv/builder_cast_expression_test.cc b/src/writer/spirv/builder_cast_expression_test.cc index c6b6fcb88d..ecb248cc4a 100644 --- a/src/writer/spirv/builder_cast_expression_test.cc +++ b/src/writer/spirv/builder_cast_expression_test.cc @@ -21,6 +21,8 @@ #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/u32_type.h" +#include "src/ast/uint_literal.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" @@ -33,9 +35,57 @@ namespace { using BuilderTest = testing::Test; -TEST_F(BuilderTest, DISABLED_Cast_FloatToU32) {} +TEST_F(BuilderTest, Cast_FloatToU32) { + ast::type::U32Type u32; + ast::type::F32Type f32; -TEST_F(BuilderTest, DISABLED_Cast_FloatToI32) {} + ast::CastExpression cast(&u32, + std::make_unique( + std::make_unique(&f32, 2.4))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.4000001 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToU %2 %4 +)"); +} + +TEST_F(BuilderTest, Cast_FloatToI32) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + ast::CastExpression cast(&i32, + std::make_unique( + std::make_unique(&f32, 2.4))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.4000001 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToS %2 %4 +)"); +} TEST_F(BuilderTest, Cast_I32ToFloat) { ast::type::I32Type i32; @@ -63,7 +113,31 @@ TEST_F(BuilderTest, Cast_I32ToFloat) { )"); } -TEST_F(BuilderTest, DISABLED_Cast_U32ToFloat) {} +TEST_F(BuilderTest, Cast_U32ToFloat) { + ast::type::U32Type u32; + ast::type::F32Type f32; + + ast::CastExpression cast(&f32, + std::make_unique( + std::make_unique(&u32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertUToF %2 %4 +)"); +} TEST_F(BuilderTest, Cast_WithLoad) { ast::type::F32Type f32; @@ -100,9 +174,160 @@ TEST_F(BuilderTest, Cast_WithLoad) { )"); } -TEST_F(BuilderTest, DISABLED_Cast_WithAlias) {} +TEST_F(BuilderTest, Cast_WithAlias) { + ast::type::I32Type i32; + ast::type::F32Type f32; -// TODO(dsinclair): Are here i32 -> u32 and u32->i32 casts? + // type Int = i32 + // cast(1.f) + + ast::type::AliasType alias("Int", &i32); + + ast::CastExpression cast(&alias, + std::make_unique( + std::make_unique(&f32, 2.3))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeFloat 32 +%4 = OpConstant %3 2.29999995 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpConvertFToS %2 %4 +)"); +} + +TEST_F(BuilderTest, Cast_I32ToU32) { + ast::type::U32Type u32; + ast::type::I32Type i32; + + ast::CastExpression cast(&u32, + std::make_unique( + std::make_unique(&i32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpTypeInt 32 1 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpBitcast %2 %4 +)"); +} + +TEST_F(BuilderTest, Cast_U32ToI32) { + ast::type::U32Type u32; + ast::type::I32Type i32; + + ast::CastExpression cast(&i32, + std::make_unique( + std::make_unique(&u32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpBitcast %2 %4 +)"); +} + +TEST_F(BuilderTest, Cast_I32ToI32) { + ast::type::I32Type i32; + + ast::CastExpression cast(&i32, + std::make_unique( + std::make_unique(&i32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Cast_U32ToU32) { + ast::type::U32Type u32; + + ast::CastExpression cast(&u32, + std::make_unique( + std::make_unique(&u32, 2))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} + +TEST_F(BuilderTest, Cast_F32ToF32) { + ast::type::F32Type f32; + + ast::CastExpression cast(&f32, + std::make_unique( + std::make_unique(&f32, 2.0))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + ASSERT_TRUE(td.DetermineResultType(&cast)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + EXPECT_EQ(b.GenerateCastExpression(&cast), 1u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpConstant %2 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpCopyObject %2 %3 +)"); +} } // namespace } // namespace spirv