// Copyright 2021 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/ast/bitcast_expression.h" #include "src/tint/resolver/resolver.h" #include "src/tint/resolver/resolver_test_helper.h" #include "gmock/gmock.h" namespace tint::resolver { namespace { struct Type { template static constexpr Type Create() { return Type{builder::DataType::AST, builder::DataType::Sem, builder::DataType::ExprFromDouble}; } builder::ast_type_func_ptr ast; builder::sem_type_func_ptr sem; builder::ast_expr_from_double_func_ptr expr; }; static constexpr Type kNumericScalars[] = { Type::Create(), Type::Create(), Type::Create(), }; static constexpr Type kVec2NumericScalars[] = { Type::Create>(), Type::Create>(), Type::Create>(), }; static constexpr Type kVec3NumericScalars[] = { Type::Create>(), Type::Create>(), Type::Create>(), }; static constexpr Type kVec4NumericScalars[] = { Type::Create>(), Type::Create>(), Type::Create>(), }; static constexpr Type kInvalid[] = { // A non-exhaustive selection of uncastable types Type::Create(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>(), Type::Create>>(), Type::Create>>(), }; using ResolverBitcastValidationTest = ResolverTestWithParam>; //////////////////////////////////////////////////////////////////////////////// // Valid bitcasts //////////////////////////////////////////////////////////////////////////////// using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest; TEST_P(ResolverBitcastValidationTestPass, Test) { auto src = std::get<0>(GetParam()); auto dst = std::get<1>(GetParam()); auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0)); WrapInFunction(cast); ASSERT_TRUE(r()->Resolve()) << r()->error(); EXPECT_EQ(TypeOf(cast), dst.sem(*this)); } INSTANTIATE_TEST_SUITE_P(Scalars, ResolverBitcastValidationTestPass, testing::Combine(testing::ValuesIn(kNumericScalars), testing::ValuesIn(kNumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec2, ResolverBitcastValidationTestPass, testing::Combine(testing::ValuesIn(kVec2NumericScalars), testing::ValuesIn(kVec2NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec3, ResolverBitcastValidationTestPass, testing::Combine(testing::ValuesIn(kVec3NumericScalars), testing::ValuesIn(kVec3NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec4, ResolverBitcastValidationTestPass, testing::Combine(testing::ValuesIn(kVec4NumericScalars), testing::ValuesIn(kVec4NumericScalars))); //////////////////////////////////////////////////////////////////////////////// // Invalid source type for bitcasts //////////////////////////////////////////////////////////////////////////////// using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest; TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) { auto src = std::get<0>(GetParam()); auto dst = std::get<1>(GetParam()); auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src")); WrapInFunction(Let("src", src.expr(*this, 0)), cast); auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) + "' cannot be bitcast"; EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), expected); } INSTANTIATE_TEST_SUITE_P(Scalars, ResolverBitcastValidationTestInvalidSrcTy, testing::Combine(testing::ValuesIn(kInvalid), testing::ValuesIn(kNumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec2, ResolverBitcastValidationTestInvalidSrcTy, testing::Combine(testing::ValuesIn(kInvalid), testing::ValuesIn(kVec2NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec3, ResolverBitcastValidationTestInvalidSrcTy, testing::Combine(testing::ValuesIn(kInvalid), testing::ValuesIn(kVec3NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec4, ResolverBitcastValidationTestInvalidSrcTy, testing::Combine(testing::ValuesIn(kInvalid), testing::ValuesIn(kVec4NumericScalars))); //////////////////////////////////////////////////////////////////////////////// // Invalid target type for bitcasts //////////////////////////////////////////////////////////////////////////////// using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest; TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) { auto src = std::get<0>(GetParam()); auto dst = std::get<1>(GetParam()); // Use an alias so we can put a Source on the bitcast type Alias("T", dst.ast(*this)); WrapInFunction(Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0))); auto expected = "12:34 error: cannot bitcast to '" + dst.sem(*this)->FriendlyName(Symbols()) + "'"; EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), expected); } INSTANTIATE_TEST_SUITE_P(Scalars, ResolverBitcastValidationTestInvalidDstTy, testing::Combine(testing::ValuesIn(kNumericScalars), testing::ValuesIn(kInvalid))); INSTANTIATE_TEST_SUITE_P(Vec2, ResolverBitcastValidationTestInvalidDstTy, testing::Combine(testing::ValuesIn(kVec2NumericScalars), testing::ValuesIn(kInvalid))); INSTANTIATE_TEST_SUITE_P(Vec3, ResolverBitcastValidationTestInvalidDstTy, testing::Combine(testing::ValuesIn(kVec3NumericScalars), testing::ValuesIn(kInvalid))); INSTANTIATE_TEST_SUITE_P(Vec4, ResolverBitcastValidationTestInvalidDstTy, testing::Combine(testing::ValuesIn(kVec4NumericScalars), testing::ValuesIn(kInvalid))); //////////////////////////////////////////////////////////////////////////////// // Incompatible bitcast, but both src and dst types are valid //////////////////////////////////////////////////////////////////////////////// using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest; TEST_P(ResolverBitcastValidationTestIncompatible, Test) { auto src = std::get<0>(GetParam()); auto dst = std::get<1>(GetParam()); WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0))); auto expected = "12:34 error: cannot bitcast from '" + src.sem(*this)->FriendlyName(Symbols()) + "' to '" + dst.sem(*this)->FriendlyName(Symbols()) + "'"; EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), expected); } INSTANTIATE_TEST_SUITE_P(ScalarToVec2, ResolverBitcastValidationTestIncompatible, testing::Combine(testing::ValuesIn(kNumericScalars), testing::ValuesIn(kVec2NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec2ToVec3, ResolverBitcastValidationTestIncompatible, testing::Combine(testing::ValuesIn(kVec2NumericScalars), testing::ValuesIn(kVec3NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec3ToVec4, ResolverBitcastValidationTestIncompatible, testing::Combine(testing::ValuesIn(kVec3NumericScalars), testing::ValuesIn(kVec4NumericScalars))); INSTANTIATE_TEST_SUITE_P(Vec4ToScalar, ResolverBitcastValidationTestIncompatible, testing::Combine(testing::ValuesIn(kVec4NumericScalars), testing::ValuesIn(kNumericScalars))); } // namespace } // namespace tint::resolver