diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 826a4b393f..b5cc0d0772 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -484,6 +484,7 @@ if(${TINT_BUILD_TESTS}) resolver/resolver_test.cc resolver/struct_layout_test.cc resolver/struct_storage_class_use_test.cc + resolver/type_constructor_validation_test.cc resolver/type_validation_test.cc resolver/validation_test.cc scope_stack_test.cc diff --git a/src/ast/variable.cc b/src/ast/variable.cc index 25a11f9ea6..793cb974ad 100644 --- a/src/ast/variable.cc +++ b/src/ast/variable.cc @@ -38,7 +38,8 @@ Variable::Variable(const Source& source, decorations_(std::move(decorations)), declared_storage_class_(declared_storage_class) { TINT_ASSERT(symbol_.IsValid()); - TINT_ASSERT(declared_type_); + // no type means we must have a constructor to infer it + TINT_ASSERT(declared_type_ || constructor); } Variable::Variable(Variable&&) = default; diff --git a/src/program_builder.cc b/src/program_builder.cc index ac5d0d1af4..bebb8be701 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -81,6 +81,44 @@ type::Type* ProgramBuilder::TypeOf(ast::Expression* expr) const { return sem ? sem->Type() : nullptr; } +ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith( + type::Type* type, + int elem_value) { + auto* unwrapped_type = type->UnwrapAliasIfNeeded(); + if (unwrapped_type->Is()) { + return create( + create(type, elem_value == 0 ? false : true)); + } + if (unwrapped_type->Is()) { + return create(create( + type, static_cast(elem_value))); + } + if (unwrapped_type->Is()) { + return create(create( + type, static_cast(elem_value))); + } + if (unwrapped_type->Is()) { + return create(create( + type, static_cast(elem_value))); + } + if (auto* v = unwrapped_type->As()) { + auto* elem_default_value = ConstructValueFilledWith(v->type(), elem_value); + ast::ExpressionList el(v->size()); + std::fill(el.begin(), el.end(), elem_default_value); + return create(type, std::move(el)); + } + if (auto* m = unwrapped_type->As()) { + auto* col_vec_type = create(m->type(), m->rows()); + auto* vec_default_value = + ConstructValueFilledWith(col_vec_type, elem_value); + ast::ExpressionList el(m->columns()); + std::fill(el.begin(), el.end(), vec_default_value); + return create(type, std::move(el)); + } + TINT_ASSERT(false); + return nullptr; +} + ProgramBuilder::TypesBuilder::TypesBuilder(ProgramBuilder* pb) : builder(pb) {} ast::VariableDeclStatement* ProgramBuilder::WrapInStatement(ast::Variable* v) { diff --git a/src/program_builder.h b/src/program_builder.h index 12b9d41af7..76c823dff2 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -449,12 +449,20 @@ class ProgramBuilder { type); } + /// @return the tint AST pointer to `type` with the given ast::StorageClass + /// @param type the type of the pointer + /// @param storage_class the storage class of the pointer + type::Pointer* pointer(type::Type* type, + ast::StorageClass storage_class) const { + return builder->create(type, storage_class); + } + /// @return the tint AST pointer to type `T` with the given /// ast::StorageClass. /// @param storage_class the storage class of the pointer template type::Pointer* pointer(ast::StorageClass storage_class) const { - return builder->create(Of(), storage_class); + return pointer(Of(), storage_class); } /// @param name the struct name @@ -619,6 +627,17 @@ class ProgramBuilder { type, ExprList(std::forward(args)...)); } + /// Creates a constructor expression that constructs an object of + /// `type` filled with `elem_value`. For example, + /// ConstructValueFilledWith(ty.mat3x4(), 5) returns a + /// TypeConstructorExpression for a Mat3x4 filled with 5.0f values. + /// @param type the type to construct + /// @param elem_value the initial or element value (for vec and mat) to + /// construct with + /// @return the constructor expression + ast::ConstructorExpression* ConstructValueFilledWith(type::Type* type, + int elem_value = 0); + /// @param args the arguments for the vector constructor /// @return an `ast::TypeConstructorExpression` of a 2-element vector of type /// `T`, constructed with the values `args`. diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index f09849c72f..38ed168da2 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1220,6 +1220,11 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { } auto* rhs_type = TypeOf(ctor); + // If the variable has no type, infer it from the rhs + if (type == nullptr) { + type = rhs_type->UnwrapPtrIfNeeded(); + } + if (!IsValidAssignment(type, rhs_type)) { diagnostics_.add_error( "variable of type '" + type->FriendlyName(builder_->Symbols()) + diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index ac59bff017..3aa69a5c8f 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -52,29 +52,6 @@ using u32 = ProgramBuilder::u32; using f32 = ProgramBuilder::f32; using Op = ast::BinaryOp; -type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) { - return ty.bool_(); -} -type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) { - return ty.i32(); -} -type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) { - return ty.u32(); -} -type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) { - return ty.f32(); -} - -template -type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { - return ty.vec3(); -} - -template -type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.mat3x3(); -} - TEST_F(ResolverTest, Stmt_Assign) { auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction); auto* lhs = Expr("v"); @@ -1015,9 +992,6 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) { namespace ExprBinaryTest { -using create_type_func_ptr = - type::Type* (*)(const ProgramBuilder::TypesBuilder& ty); - struct Params { ast::BinaryOp op; create_type_func_ptr create_lhs_type; diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index c06b0a38c0..57d820d838 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -77,6 +77,38 @@ template class ResolverTestWithParam : public TestHelper, public testing::TestWithParam {}; +inline type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) { + return ty.bool_(); +} +inline type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) { + return ty.i32(); +} +inline type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) { + return ty.u32(); +} +inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) { + return ty.f32(); +} + +template +type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec3(); +} + +template +type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x3(); +} + +using create_type_func_ptr = + type::Type* (*)(const ProgramBuilder::TypesBuilder& ty); + +template +type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) { + auto* type = create_type(ty); + return ty.alias("alias_" + type->type_name(), type); +} + } // namespace resolver } // namespace tint diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc new file mode 100644 index 0000000000..2952d7ba0e --- /dev/null +++ b/src/resolver/type_constructor_validation_test.cc @@ -0,0 +1,211 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/resolver/resolver_test_helper.h" + +namespace tint { +namespace resolver { +namespace { + +/// @return the element type of `type` for vec and mat, otherwise `type` itself +type::Type* ElementTypeOf(type::Type* type) { + if (auto* v = type->As()) { + return v->type(); + } + if (auto* m = type->As()) { + return m->type(); + } + return type; +} + +class ResolverTypeConstructorValidationTest : public resolver::TestHelper, + public testing::Test {}; + +namespace InferTypeTest { +struct Params { + create_type_func_ptr create_rhs_type; +}; + +// Helpers and typedefs +using i32 = ProgramBuilder::i32; +using u32 = ProgramBuilder::u32; +using f32 = ProgramBuilder::f32; + +TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) { + // var a = 1; + // var b = a; + auto sc = ast::StorageClass::kFunction; + auto* a = Var("a", nullptr, sc, Expr(1)); + auto* b = Var("b", nullptr, sc, Expr("a")); + auto* a_ident = Expr("a"); + auto* b_ident = Expr("b"); + + WrapInFunction(Decl(a), Decl(b), Assign(a_ident, a_ident), + Assign(b_ident, b_ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + ASSERT_EQ(TypeOf(a_ident), ty.pointer(ty.i32(), sc)); + ASSERT_EQ(TypeOf(b_ident), ty.pointer(ty.i32(), sc)); +} + +using InferTypeTest_FromConstructorExpression = ResolverTestWithParam; +TEST_P(InferTypeTest_FromConstructorExpression, All) { + // e.g. for vec3 + // { + // var a = vec3(0.0, 0.0, 0.0) + // } + auto& params = GetParam(); + + auto* rhs_type = params.create_rhs_type(ty); + auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0); + + auto sc = ast::StorageClass::kFunction; + auto* a = Var("a", nullptr, sc, constructor_expr); + // Self-assign 'a' to force the expression to be resolved so we can test its + // type below + auto* a_ident = Expr("a"); + WrapInFunction(Decl(a), Assign(a_ident, a_ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc)); +} + +static constexpr Params from_constructor_expression_cases[] = { + Params{ty_bool_}, + Params{ty_i32}, + Params{ty_u32}, + Params{ty_f32}, + Params{ty_vec3}, + Params{ty_vec3}, + Params{ty_vec3}, + Params{ty_mat3x3}, + Params{ty_mat3x3}, + Params{ty_mat3x3}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, +}; +INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, + InferTypeTest_FromConstructorExpression, + testing::ValuesIn(from_constructor_expression_cases)); + +using InferTypeTest_FromArithmeticExpression = ResolverTestWithParam; +TEST_P(InferTypeTest_FromArithmeticExpression, All) { + // e.g. for vec3 + // { + // var a = vec3(2.0, 2.0, 2.0) * 3.0; + // } + auto& params = GetParam(); + + auto* rhs_type = params.create_rhs_type(ty); + + auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2); + auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3); + auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr); + + auto sc = ast::StorageClass::kFunction; + auto* a = Var("a", nullptr, sc, constructor_expr); + // Self-assign 'a' to force the expression to be resolved so we can test its + // type below + auto* a_ident = Expr("a"); + WrapInFunction(Decl(a), Assign(a_ident, a_ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc)); +} +static constexpr Params from_arithmetic_expression_cases[] = { + Params{ty_i32}, Params{ty_u32}, Params{ty_f32}, + Params{ty_vec3}, Params{ty_mat3x3}, + + // TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed + // Params{ty_alias}, + // Params{ty_alias}, + // Params{ty_alias}, + // Params{ty_alias>}, + // Params{ty_alias>}, + +}; +INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, + InferTypeTest_FromArithmeticExpression, + testing::ValuesIn(from_arithmetic_expression_cases)); + +using InferTypeTest_FromCallExpression = ResolverTestWithParam; +TEST_P(InferTypeTest_FromCallExpression, All) { + // e.g. for vec3 + // + // fn foo() -> vec3 { + // return vec3(0.0, 0.0, 0.0); + // } + // + // fn bar() -> void + // { + // var a = foo(); + // } + auto& params = GetParam(); + + auto* rhs_type = params.create_rhs_type(ty); + + Func("foo", {}, rhs_type, {Return(ConstructValueFilledWith(rhs_type, 0))}, + {}); + auto* constructor_expr = Call(Expr("foo")); + + auto sc = ast::StorageClass::kFunction; + auto* a = Var("a", nullptr, sc, constructor_expr); + // Self-assign 'a' to force the expression to be resolved so we can test its + // type below + auto* a_ident = Expr("a"); + WrapInFunction(Decl(a), Assign(a_ident, a_ident)); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc)); +} +static constexpr Params from_call_expression_cases[] = { + Params{ty_bool_}, + Params{ty_i32}, + Params{ty_u32}, + Params{ty_f32}, + Params{ty_vec3}, + Params{ty_vec3}, + Params{ty_vec3}, + Params{ty_mat3x3}, + Params{ty_mat3x3}, + Params{ty_mat3x3}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + Params{ty_alias>}, + +}; +INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, + InferTypeTest_FromCallExpression, + testing::ValuesIn(from_call_expression_cases)); + +} // namespace InferTypeTest + +} // namespace +} // namespace resolver +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index 9ec26a1d3a..c7f99146a3 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -182,6 +182,7 @@ source_set("tint_unittests_core_src") { "../src/resolver/resolver_test_helper.h", "../src/resolver/struct_layout_test.cc", "../src/resolver/struct_storage_class_use_test.cc", + "../src/resolver/type_constructor_validation_test.cc", "../src/resolver/type_validation_test.cc", "../src/resolver/validation_test.cc", "../src/scope_stack_test.cc",