Add type inference support to Resolver
There is still no way to spell this out in WGSL, but this adds support for VariableDecls with an ast::Variable that has nullptr type. In this case, the Resolver uses the type of the rhs (constructor expression), which is stored in semantic::Variable. Added tests for resolving inferred types from constructor, arithmetic, and call expressions. Bug: tint:672 Change-Id: I3dcfd18adecebc8b969373d2ac72c21891c21a87 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46160 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
55bc5409c2
commit
39a65a1d1e
|
@ -484,6 +484,7 @@ if(${TINT_BUILD_TESTS})
|
|||
resolver/resolver_test.cc
|
||||
resolver/struct_layout_test.cc
|
||||
resolver/struct_storage_class_use_test.cc
|
||||
resolver/type_constructor_validation_test.cc
|
||||
resolver/type_validation_test.cc
|
||||
resolver/validation_test.cc
|
||||
scope_stack_test.cc
|
||||
|
|
|
@ -38,7 +38,8 @@ Variable::Variable(const Source& source,
|
|||
decorations_(std::move(decorations)),
|
||||
declared_storage_class_(declared_storage_class) {
|
||||
TINT_ASSERT(symbol_.IsValid());
|
||||
TINT_ASSERT(declared_type_);
|
||||
// no type means we must have a constructor to infer it
|
||||
TINT_ASSERT(declared_type_ || constructor);
|
||||
}
|
||||
|
||||
Variable::Variable(Variable&&) = default;
|
||||
|
|
|
@ -81,6 +81,44 @@ type::Type* ProgramBuilder::TypeOf(ast::Expression* expr) const {
|
|||
return sem ? sem->Type() : nullptr;
|
||||
}
|
||||
|
||||
ast::ConstructorExpression* ProgramBuilder::ConstructValueFilledWith(
|
||||
type::Type* type,
|
||||
int elem_value) {
|
||||
auto* unwrapped_type = type->UnwrapAliasIfNeeded();
|
||||
if (unwrapped_type->Is<type::Bool>()) {
|
||||
return create<ast::ScalarConstructorExpression>(
|
||||
create<ast::BoolLiteral>(type, elem_value == 0 ? false : true));
|
||||
}
|
||||
if (unwrapped_type->Is<type::I32>()) {
|
||||
return create<ast::ScalarConstructorExpression>(create<ast::SintLiteral>(
|
||||
type, static_cast<ProgramBuilder::i32>(elem_value)));
|
||||
}
|
||||
if (unwrapped_type->Is<type::U32>()) {
|
||||
return create<ast::ScalarConstructorExpression>(create<ast::UintLiteral>(
|
||||
type, static_cast<ProgramBuilder::u32>(elem_value)));
|
||||
}
|
||||
if (unwrapped_type->Is<type::F32>()) {
|
||||
return create<ast::ScalarConstructorExpression>(create<ast::FloatLiteral>(
|
||||
type, static_cast<ProgramBuilder::f32>(elem_value)));
|
||||
}
|
||||
if (auto* v = unwrapped_type->As<type::Vector>()) {
|
||||
auto* elem_default_value = ConstructValueFilledWith(v->type(), elem_value);
|
||||
ast::ExpressionList el(v->size());
|
||||
std::fill(el.begin(), el.end(), elem_default_value);
|
||||
return create<ast::TypeConstructorExpression>(type, std::move(el));
|
||||
}
|
||||
if (auto* m = unwrapped_type->As<type::Matrix>()) {
|
||||
auto* col_vec_type = create<type::Vector>(m->type(), m->rows());
|
||||
auto* vec_default_value =
|
||||
ConstructValueFilledWith(col_vec_type, elem_value);
|
||||
ast::ExpressionList el(m->columns());
|
||||
std::fill(el.begin(), el.end(), vec_default_value);
|
||||
return create<ast::TypeConstructorExpression>(type, std::move(el));
|
||||
}
|
||||
TINT_ASSERT(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ProgramBuilder::TypesBuilder::TypesBuilder(ProgramBuilder* pb) : builder(pb) {}
|
||||
|
||||
ast::VariableDeclStatement* ProgramBuilder::WrapInStatement(ast::Variable* v) {
|
||||
|
|
|
@ -449,12 +449,20 @@ class ProgramBuilder {
|
|||
type);
|
||||
}
|
||||
|
||||
/// @return the tint AST pointer to `type` with the given ast::StorageClass
|
||||
/// @param type the type of the pointer
|
||||
/// @param storage_class the storage class of the pointer
|
||||
type::Pointer* pointer(type::Type* type,
|
||||
ast::StorageClass storage_class) const {
|
||||
return builder->create<type::Pointer>(type, storage_class);
|
||||
}
|
||||
|
||||
/// @return the tint AST pointer to type `T` with the given
|
||||
/// ast::StorageClass.
|
||||
/// @param storage_class the storage class of the pointer
|
||||
template <typename T>
|
||||
type::Pointer* pointer(ast::StorageClass storage_class) const {
|
||||
return builder->create<type::Pointer>(Of<T>(), storage_class);
|
||||
return pointer(Of<T>(), storage_class);
|
||||
}
|
||||
|
||||
/// @param name the struct name
|
||||
|
@ -619,6 +627,17 @@ class ProgramBuilder {
|
|||
type, ExprList(std::forward<ARGS>(args)...));
|
||||
}
|
||||
|
||||
/// Creates a constructor expression that constructs an object of
|
||||
/// `type` filled with `elem_value`. For example,
|
||||
/// ConstructValueFilledWith(ty.mat3x4<float>(), 5) returns a
|
||||
/// TypeConstructorExpression for a Mat3x4 filled with 5.0f values.
|
||||
/// @param type the type to construct
|
||||
/// @param elem_value the initial or element value (for vec and mat) to
|
||||
/// construct with
|
||||
/// @return the constructor expression
|
||||
ast::ConstructorExpression* ConstructValueFilledWith(type::Type* type,
|
||||
int elem_value = 0);
|
||||
|
||||
/// @param args the arguments for the vector constructor
|
||||
/// @return an `ast::TypeConstructorExpression` of a 2-element vector of type
|
||||
/// `T`, constructed with the values `args`.
|
||||
|
|
|
@ -1220,6 +1220,11 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
|
|||
}
|
||||
auto* rhs_type = TypeOf(ctor);
|
||||
|
||||
// If the variable has no type, infer it from the rhs
|
||||
if (type == nullptr) {
|
||||
type = rhs_type->UnwrapPtrIfNeeded();
|
||||
}
|
||||
|
||||
if (!IsValidAssignment(type, rhs_type)) {
|
||||
diagnostics_.add_error(
|
||||
"variable of type '" + type->FriendlyName(builder_->Symbols()) +
|
||||
|
|
|
@ -52,29 +52,6 @@ using u32 = ProgramBuilder::u32;
|
|||
using f32 = ProgramBuilder::f32;
|
||||
using Op = ast::BinaryOp;
|
||||
|
||||
type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.bool_();
|
||||
}
|
||||
type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.i32();
|
||||
}
|
||||
type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.u32();
|
||||
}
|
||||
type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.f32();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.vec3<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x3<T>();
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Assign) {
|
||||
auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction);
|
||||
auto* lhs = Expr("v");
|
||||
|
@ -1015,9 +992,6 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
|
|||
|
||||
namespace ExprBinaryTest {
|
||||
|
||||
using create_type_func_ptr =
|
||||
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
|
||||
|
||||
struct Params {
|
||||
ast::BinaryOp op;
|
||||
create_type_func_ptr create_lhs_type;
|
||||
|
|
|
@ -77,6 +77,38 @@ template <typename T>
|
|||
class ResolverTestWithParam : public TestHelper,
|
||||
public testing::TestWithParam<T> {};
|
||||
|
||||
inline type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.bool_();
|
||||
}
|
||||
inline type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.i32();
|
||||
}
|
||||
inline type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.u32();
|
||||
}
|
||||
inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.f32();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.vec3<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x3<T>();
|
||||
}
|
||||
|
||||
using create_type_func_ptr =
|
||||
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
|
||||
|
||||
template <create_type_func_ptr create_type>
|
||||
type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* type = create_type(ty);
|
||||
return ty.alias("alias_" + type->type_name(), type);
|
||||
}
|
||||
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/resolver/resolver_test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace resolver {
|
||||
namespace {
|
||||
|
||||
/// @return the element type of `type` for vec and mat, otherwise `type` itself
|
||||
type::Type* ElementTypeOf(type::Type* type) {
|
||||
if (auto* v = type->As<type::Vector>()) {
|
||||
return v->type();
|
||||
}
|
||||
if (auto* m = type->As<type::Matrix>()) {
|
||||
return m->type();
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
class ResolverTypeConstructorValidationTest : public resolver::TestHelper,
|
||||
public testing::Test {};
|
||||
|
||||
namespace InferTypeTest {
|
||||
struct Params {
|
||||
create_type_func_ptr create_rhs_type;
|
||||
};
|
||||
|
||||
// Helpers and typedefs
|
||||
using i32 = ProgramBuilder::i32;
|
||||
using u32 = ProgramBuilder::u32;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(ResolverTypeConstructorValidationTest, InferTypeTest_Simple) {
|
||||
// var a = 1;
|
||||
// var b = a;
|
||||
auto sc = ast::StorageClass::kFunction;
|
||||
auto* a = Var("a", nullptr, sc, Expr(1));
|
||||
auto* b = Var("b", nullptr, sc, Expr("a"));
|
||||
auto* a_ident = Expr("a");
|
||||
auto* b_ident = Expr("b");
|
||||
|
||||
WrapInFunction(Decl(a), Decl(b), Assign(a_ident, a_ident),
|
||||
Assign(b_ident, b_ident));
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_EQ(TypeOf(a_ident), ty.pointer(ty.i32(), sc));
|
||||
ASSERT_EQ(TypeOf(b_ident), ty.pointer(ty.i32(), sc));
|
||||
}
|
||||
|
||||
using InferTypeTest_FromConstructorExpression = ResolverTestWithParam<Params>;
|
||||
TEST_P(InferTypeTest_FromConstructorExpression, All) {
|
||||
// e.g. for vec3<f32>
|
||||
// {
|
||||
// var a = vec3<f32>(0.0, 0.0, 0.0)
|
||||
// }
|
||||
auto& params = GetParam();
|
||||
|
||||
auto* rhs_type = params.create_rhs_type(ty);
|
||||
auto* constructor_expr = ConstructValueFilledWith(rhs_type, 0);
|
||||
|
||||
auto sc = ast::StorageClass::kFunction;
|
||||
auto* a = Var("a", nullptr, sc, constructor_expr);
|
||||
// Self-assign 'a' to force the expression to be resolved so we can test its
|
||||
// type below
|
||||
auto* a_ident = Expr("a");
|
||||
WrapInFunction(Decl(a), Assign(a_ident, a_ident));
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
|
||||
}
|
||||
|
||||
static constexpr Params from_constructor_expression_cases[] = {
|
||||
Params{ty_bool_},
|
||||
Params{ty_i32},
|
||||
Params{ty_u32},
|
||||
Params{ty_f32},
|
||||
Params{ty_vec3<i32>},
|
||||
Params{ty_vec3<u32>},
|
||||
Params{ty_vec3<f32>},
|
||||
Params{ty_mat3x3<i32>},
|
||||
Params{ty_mat3x3<u32>},
|
||||
Params{ty_mat3x3<f32>},
|
||||
Params{ty_alias<ty_bool_>},
|
||||
Params{ty_alias<ty_i32>},
|
||||
Params{ty_alias<ty_u32>},
|
||||
Params{ty_alias<ty_f32>},
|
||||
Params{ty_alias<ty_vec3<i32>>},
|
||||
Params{ty_alias<ty_vec3<u32>>},
|
||||
Params{ty_alias<ty_vec3<f32>>},
|
||||
Params{ty_alias<ty_mat3x3<i32>>},
|
||||
Params{ty_alias<ty_mat3x3<u32>>},
|
||||
Params{ty_alias<ty_mat3x3<f32>>},
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
|
||||
InferTypeTest_FromConstructorExpression,
|
||||
testing::ValuesIn(from_constructor_expression_cases));
|
||||
|
||||
using InferTypeTest_FromArithmeticExpression = ResolverTestWithParam<Params>;
|
||||
TEST_P(InferTypeTest_FromArithmeticExpression, All) {
|
||||
// e.g. for vec3<f32>
|
||||
// {
|
||||
// var a = vec3<f32>(2.0, 2.0, 2.0) * 3.0;
|
||||
// }
|
||||
auto& params = GetParam();
|
||||
|
||||
auto* rhs_type = params.create_rhs_type(ty);
|
||||
|
||||
auto* arith_lhs_expr = ConstructValueFilledWith(rhs_type, 2);
|
||||
auto* arith_rhs_expr = ConstructValueFilledWith(ElementTypeOf(rhs_type), 3);
|
||||
auto* constructor_expr = Mul(arith_lhs_expr, arith_rhs_expr);
|
||||
|
||||
auto sc = ast::StorageClass::kFunction;
|
||||
auto* a = Var("a", nullptr, sc, constructor_expr);
|
||||
// Self-assign 'a' to force the expression to be resolved so we can test its
|
||||
// type below
|
||||
auto* a_ident = Expr("a");
|
||||
WrapInFunction(Decl(a), Assign(a_ident, a_ident));
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
|
||||
}
|
||||
static constexpr Params from_arithmetic_expression_cases[] = {
|
||||
Params{ty_i32}, Params{ty_u32}, Params{ty_f32},
|
||||
Params{ty_vec3<f32>}, Params{ty_mat3x3<f32>},
|
||||
|
||||
// TODO(amaiorano): Uncomment once https://crbug.com/tint/680 is fixed
|
||||
// Params{ty_alias<ty_i32>},
|
||||
// Params{ty_alias<ty_u32>},
|
||||
// Params{ty_alias<ty_f32>},
|
||||
// Params{ty_alias<ty_vec3<f32>>},
|
||||
// Params{ty_alias<ty_mat3x3<f32>>},
|
||||
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
|
||||
InferTypeTest_FromArithmeticExpression,
|
||||
testing::ValuesIn(from_arithmetic_expression_cases));
|
||||
|
||||
using InferTypeTest_FromCallExpression = ResolverTestWithParam<Params>;
|
||||
TEST_P(InferTypeTest_FromCallExpression, All) {
|
||||
// e.g. for vec3<f32>
|
||||
//
|
||||
// fn foo() -> vec3<f32> {
|
||||
// return vec3<f32>(0.0, 0.0, 0.0);
|
||||
// }
|
||||
//
|
||||
// fn bar() -> void
|
||||
// {
|
||||
// var a = foo();
|
||||
// }
|
||||
auto& params = GetParam();
|
||||
|
||||
auto* rhs_type = params.create_rhs_type(ty);
|
||||
|
||||
Func("foo", {}, rhs_type, {Return(ConstructValueFilledWith(rhs_type, 0))},
|
||||
{});
|
||||
auto* constructor_expr = Call(Expr("foo"));
|
||||
|
||||
auto sc = ast::StorageClass::kFunction;
|
||||
auto* a = Var("a", nullptr, sc, constructor_expr);
|
||||
// Self-assign 'a' to force the expression to be resolved so we can test its
|
||||
// type below
|
||||
auto* a_ident = Expr("a");
|
||||
WrapInFunction(Decl(a), Assign(a_ident, a_ident));
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_EQ(TypeOf(a_ident), ty.pointer(rhs_type, sc));
|
||||
}
|
||||
static constexpr Params from_call_expression_cases[] = {
|
||||
Params{ty_bool_},
|
||||
Params{ty_i32},
|
||||
Params{ty_u32},
|
||||
Params{ty_f32},
|
||||
Params{ty_vec3<i32>},
|
||||
Params{ty_vec3<u32>},
|
||||
Params{ty_vec3<f32>},
|
||||
Params{ty_mat3x3<i32>},
|
||||
Params{ty_mat3x3<u32>},
|
||||
Params{ty_mat3x3<f32>},
|
||||
Params{ty_alias<ty_bool_>},
|
||||
Params{ty_alias<ty_i32>},
|
||||
Params{ty_alias<ty_u32>},
|
||||
Params{ty_alias<ty_f32>},
|
||||
Params{ty_alias<ty_vec3<i32>>},
|
||||
Params{ty_alias<ty_vec3<u32>>},
|
||||
Params{ty_alias<ty_vec3<f32>>},
|
||||
Params{ty_alias<ty_mat3x3<i32>>},
|
||||
Params{ty_alias<ty_mat3x3<u32>>},
|
||||
Params{ty_alias<ty_mat3x3<f32>>},
|
||||
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
|
||||
InferTypeTest_FromCallExpression,
|
||||
testing::ValuesIn(from_call_expression_cases));
|
||||
|
||||
} // namespace InferTypeTest
|
||||
|
||||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
|
@ -182,6 +182,7 @@ source_set("tint_unittests_core_src") {
|
|||
"../src/resolver/resolver_test_helper.h",
|
||||
"../src/resolver/struct_layout_test.cc",
|
||||
"../src/resolver/struct_storage_class_use_test.cc",
|
||||
"../src/resolver/type_constructor_validation_test.cc",
|
||||
"../src/resolver/type_validation_test.cc",
|
||||
"../src/resolver/validation_test.cc",
|
||||
"../src/scope_stack_test.cc",
|
||||
|
|
Loading…
Reference in New Issue