diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9db881d348..de4d2d4b66 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -673,6 +673,7 @@ if(${TINT_BUILD_TESTS}) resolver/assignment_validation_test.cc resolver/atomics_test.cc resolver/atomics_validation_test.cc + resolver/bitcast_validation_test.cc resolver/builtins_validation_test.cc resolver/call_test.cc resolver/call_validation_test.cc diff --git a/src/resolver/bitcast_validation_test.cc b/src/resolver/bitcast_validation_test.cc new file mode 100644 index 0000000000..d4ce0823db --- /dev/null +++ b/src/resolver/bitcast_validation_test.cc @@ -0,0 +1,228 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/bitcast_expression.h" +#include "src/resolver/resolver.h" +#include "src/resolver/resolver_test_helper.h" + +#include "gmock/gmock.h" + +namespace tint { +namespace resolver { +namespace { + +struct Type { + template + static constexpr Type Create() { + return Type{builder::DataType::AST, builder::DataType::Sem, + builder::DataType::Expr}; + } + + builder::ast_type_func_ptr ast; + builder::sem_type_func_ptr sem; + builder::ast_expr_func_ptr expr; +}; + +static constexpr Type kNumericScalars[] = { + Type::Create(), + Type::Create(), + Type::Create(), +}; +static constexpr Type kVec2NumericScalars[] = { + Type::Create>(), + Type::Create>(), + Type::Create>(), +}; +static constexpr Type kVec3NumericScalars[] = { + Type::Create>(), + Type::Create>(), + Type::Create>(), +}; +static constexpr Type kVec4NumericScalars[] = { + Type::Create>(), + Type::Create>(), + Type::Create>(), +}; +static constexpr Type kInvalid[] = { + // A non-exhaustive selection of uncastable types + Type::Create(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>(), + Type::Create>>(), + Type::Create>>(), +}; + +using ResolverBitcastValidationTest = + ResolverTestWithParam>; + +//////////////////////////////////////////////////////////////////////////////// +// Valid bitcasts +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestPass, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0)); + WrapInFunction(cast); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + EXPECT_EQ(TypeOf(cast), dst.sem(*this)); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kNumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kVec4NumericScalars))); + +//////////////////////////////////////////////////////////////////////////////// +// Invalid source type for bitcasts +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src")); + WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast); + + auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) + + "' cannot be bitcast"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kNumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec4NumericScalars))); + +//////////////////////////////////////////////////////////////////////////////// +// Invalid target type for bitcasts +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + // Use an alias so we can put a Source on the bitcast type + Alias("T", dst.ast(*this)); + WrapInFunction( + Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0))); + + auto expected = "12:34 error: cannot bitcast to '" + + dst.sem(*this)->FriendlyName(Symbols()) + "'"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kInvalid))); + +//////////////////////////////////////////////////////////////////////////////// +// Incompatible bitcast, but both src and dst types are valid +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestIncompatible, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0))); + + auto expected = "12:34 error: cannot bitcast from '" + + src.sem(*this)->FriendlyName(Symbols()) + "' to '" + + dst.sem(*this)->FriendlyName(Symbols()) + "'"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P( + ScalarToVec2, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2ToVec3, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3ToVec4, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kVec4NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4ToScalar, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kNumericScalars))); + +} // namespace +} // namespace resolver +} // namespace tint diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc index 947be993dd..8627dcbdcd 100644 --- a/src/resolver/dependency_graph.cc +++ b/src/resolver/dependency_graph.cc @@ -329,6 +329,9 @@ class DependencyScanner { utils::Lookup(graph_.resolved_symbols, call->target.type)); } } + if (auto* cast = expr->As()) { + TraverseType(cast->type); + } return ast::TraverseAction::Descend; }); } diff --git a/src/resolver/dependency_graph_test.cc b/src/resolver/dependency_graph_test.cc index 5fd048d428..dcbcd7272d 100644 --- a/src/resolver/dependency_graph_test.cc +++ b/src/resolver/dependency_graph_test.cc @@ -1282,6 +1282,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { Block(Assign(V, V)), // Else(V, // Block(Assign(V, V)))), // + Ignore(Bitcast(T, V)), // For(Decl(Var(Sym(), T, V)), // Equal(V, V), // Assign(V, V), // diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc index 57aea322be..367b209aaf 100644 --- a/src/resolver/ptr_ref_validation_test.cc +++ b/src/resolver/ptr_ref_validation_test.cc @@ -171,19 +171,6 @@ TEST_F(ResolverPtrRefValidationTest, InferredPtrAccessMismatch) { "'ptr'"); } -TEST_F(ResolverTest, Expr_Bitcast_ptr) { - auto* vf = Var("vf", ty.f32()); - auto* bitcast = create( - Source{{12, 34}}, ty.pointer(ast::StorageClass::kFunction), - Expr("vf")); - auto* ip = - Const("ip", ty.pointer(ast::StorageClass::kFunction), bitcast); - WrapInFunction(Decl(vf), Decl(ip)); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer"); -} - } // namespace } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index c53de55f22..abe47c4d18 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1221,15 +1221,17 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { if (!ty) { return nullptr; } - if (ty->Is()) { - AddError("cannot cast to a pointer", expr->source); - return nullptr; - } auto val = EvaluateConstantValue(expr, ty); auto* sem = builder_->create(expr, ty, current_statement_, val); + sem->Behaviors() = inner->Behaviors(); + + if (!ValidateBitcast(expr, ty)) { + return nullptr; + } + return sem; } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 2de937e3eb..e7f6deb393 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -238,6 +238,7 @@ class Resolver { bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAssignment(const ast::AssignmentStatement* a); + bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to); bool ValidateBreakStatement(const sem::Statement* stmt); bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, const sem::Type* storage_type, diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 9f5ff52801..bc50d93961 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -171,6 +171,9 @@ using alias2 = alias; template using alias3 = alias; +template +struct ptr {}; + using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, int elem_value); @@ -387,6 +390,36 @@ struct DataType> { } }; +/// Helper for building pointer types and expressions +template +struct DataType> { + /// true if the pointer type is a composite type + static constexpr bool is_composite = false; + + /// @param b the ProgramBuilder + /// @return a new AST alias type + static inline const ast::Type* AST(ProgramBuilder& b) { + return b.create(DataType::AST(b), + ast::StorageClass::kPrivate, + ast::Access::kReadWrite); + } + /// @param b the ProgramBuilder + /// @return the semantic aliased type + static inline const sem::Type* Sem(ProgramBuilder& b) { + return b.create(DataType::Sem(b), + ast::StorageClass::kPrivate, + ast::Access::kReadWrite); + } + + /// @param b the ProgramBuilder + /// @return a new AST expression of the alias type + static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) { + auto sym = b.Symbols().New("global_for_ptr"); + b.Global(sym, DataType::AST(b), ast::StorageClass::kPrivate); + return b.AddressOf(sym); + } +}; + /// Helper for building array types and expressions template struct DataType> { @@ -401,7 +434,14 @@ struct DataType> { /// @param b the ProgramBuilder /// @return the semantic array type static inline const sem::Type* Sem(ProgramBuilder& b) { - return b.create(DataType::Sem(b), N); + auto* el = DataType::Sem(b); + return b.create( + /* element */ el, + /* count */ N, + /* align */ el->Align(), + /* size */ el->Size(), + /* stride */ el->Align(), + /* implicit_stride */ el->Align()); } /// @param b the ProgramBuilder /// @param elem_value the value each element in the array will be initialized diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc index f37633589a..7c6afefb68 100644 --- a/src/resolver/resolver_validation.cc +++ b/src/resolver/resolver_validation.cc @@ -1347,6 +1347,36 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) { return true; } +bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, + const sem::Type* to) { + auto* from = TypeOf(cast->expr)->UnwrapRef(); + if (!from->is_numeric_scalar_or_vector()) { + AddError("'" + TypeNameOf(from) + "' cannot be bitcast", + cast->expr->source); + return false; + } + if (!to->is_numeric_scalar_or_vector()) { + AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source); + return false; + } + + auto width = [&](const sem::Type* ty) { + if (auto* vec = ty->As()) { + return vec->Width(); + } + return 1u; + }; + + if (width(from) != width(to)) { + AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" + + TypeNameOf(to) + "'", + cast->source); + return false; + } + + return true; +} + bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { if (!stmt->FindFirstParent()) { AddError("break statement must be in a loop or switch case", diff --git a/src/writer/glsl/generator_impl_binary_test.cc b/src/writer/glsl/generator_impl_binary_test.cc index 5c397e974a..5ef5339989 100644 --- a/src/writer/glsl/generator_impl_binary_test.cc +++ b/src/writer/glsl/generator_impl_binary_test.cc @@ -452,36 +452,6 @@ bool a = (tint_tmp); )"); } -TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) { - // as(a && (b || c)) - - Global("a", ty.bool_(), ast::StorageClass::kPrivate); - Global("b", ty.bool_(), ast::StorageClass::kPrivate); - Global("c", ty.bool_(), ast::StorageClass::kPrivate); - - auto* expr = create( - ty.i32(), create( - ast::BinaryOp::kLogicalAnd, Expr("a"), - create(ast::BinaryOp::kLogicalOr, - Expr("b"), Expr("c")))); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - std::stringstream out; - ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(gen.result(), R"(bool tint_tmp = a; -if (tint_tmp) { - bool tint_tmp_1 = b; - if (!tint_tmp_1) { - tint_tmp_1 = c; - } - tint_tmp = (tint_tmp_1); -} -)"); - EXPECT_EQ(out.str(), R"(int((tint_tmp)))"); -} - TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) { // foo(a && b, c || d, (a || c) && (b || d)) diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 44d27b77a1..c14a89a6c6 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -452,36 +452,6 @@ bool a = (tint_tmp); )"); } -TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) { - // as(a && (b || c)) - - Global("a", ty.bool_(), ast::StorageClass::kPrivate); - Global("b", ty.bool_(), ast::StorageClass::kPrivate); - Global("c", ty.bool_(), ast::StorageClass::kPrivate); - - auto* expr = create( - ty.i32(), create( - ast::BinaryOp::kLogicalAnd, Expr("a"), - create(ast::BinaryOp::kLogicalOr, - Expr("b"), Expr("c")))); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - std::stringstream out; - ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(gen.result(), R"(bool tint_tmp = a; -if (tint_tmp) { - bool tint_tmp_1 = b; - if (!tint_tmp_1) { - tint_tmp_1 = c; - } - tint_tmp = (tint_tmp_1); -} -)"); - EXPECT_EQ(out.str(), R"(asint((tint_tmp)))"); -} - TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) { // foo(a && b, c || d, (a || c) && (b || d)) diff --git a/test/BUILD.gn b/test/BUILD.gn index a4467af75e..cad09fb28f 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -236,6 +236,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") { "../src/resolver/assignment_validation_test.cc", "../src/resolver/atomics_test.cc", "../src/resolver/atomics_validation_test.cc", + "../src/resolver/bitcast_validation_test.cc", "../src/resolver/builtins_validation_test.cc", "../src/resolver/call_test.cc", "../src/resolver/call_validation_test.cc",