diff --git a/src/type_determiner.cc b/src/type_determiner.cc index b5b9ff1fc7..bb25fbf532 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -19,6 +19,7 @@ #include "src/ast/case_statement.h" #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" +#include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/loop_statement.h" #include "src/ast/regardless_statement.h" @@ -176,6 +177,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { if (expr->IsConstructor()) { return DetermineConstructor(expr->AsConstructor()); } + if (expr->IsIdentifier()) { + return DetermineIdentifier(expr->AsIdentifier()); + } error_ = "unknown expression for type determination"; return false; @@ -190,4 +194,28 @@ bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) { return true; } +bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { + if (expr->name().size() > 1) { + // TODO(dsinclair): Handle imports + error_ = "imports not handled in type determination"; + return false; + } + + auto name = expr->name()[0]; + ast::Variable* var; + if (variable_stack_.get(name, &var)) { + expr->set_result_type(var->type()); + return true; + } + + auto iter = name_to_function_.find(name); + if (iter != name_to_function_.end()) { + expr->set_result_type(iter->second->return_type()); + return true; + } + + error_ = "unknown identifier for type determination"; + return false; +} + } // namespace tint diff --git a/src/type_determiner.h b/src/type_determiner.h index f8d8388ac4..03ad01d7e4 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -26,6 +26,7 @@ namespace tint { namespace ast { class ConstructorExpression; +class IdentifierExpression; class Function; class Variable; @@ -69,7 +70,7 @@ class TypeDeterminer { private: bool DetermineConstructor(ast::ConstructorExpression* expr); - + bool DetermineIdentifier(ast::IdentifierExpression* expr); Context& ctx_; std::string error_; ScopeStack variable_stack_; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index da8c7dff9b..c4fec3a694 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -24,6 +24,7 @@ #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/float_literal.h" +#include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/int_literal.h" #include "src/ast/loop_statement.h" @@ -406,5 +407,64 @@ TEST_F(TypeDeterminerTest, Expr_Constructor_Type) { EXPECT_EQ(tc.result_type()->AsVector()->size(), 3); } +TEST_F(TypeDeterminerTest, Expr_Identifier_GlobalVariable) { + ast::type::F32Type f32; + + ast::Module m; + auto var = + std::make_unique("my_var", ast::StorageClass::kNone, &f32); + m.AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine(&m)); + + ast::IdentifierExpression ident("my_var"); + EXPECT_TRUE(td()->DetermineResultType(&ident)); + ASSERT_NE(ident.result_type(), nullptr); + EXPECT_TRUE(ident.result_type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { + ast::type::F32Type f32; + + auto my_var = std::make_unique("my_var"); + auto my_var_ptr = my_var.get(); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("my_var", ast::StorageClass::kNone, + &f32))); + + body.push_back(std::make_unique( + std::move(my_var), + std::make_unique("my_var"))); + + ast::Function f("my_func", {}, &f32); + f.set_body(std::move(body)); + + EXPECT_TRUE(td()->DetermineFunction(&f)); + + ASSERT_NE(my_var_ptr->result_type(), nullptr); + EXPECT_TRUE(my_var_ptr->result_type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { + ast::type::F32Type f32; + + ast::VariableList params; + auto func = + std::make_unique("my_func", std::move(params), &f32); + ast::Module m; + m.AddFunction(std::move(func)); + + // Register the function + EXPECT_TRUE(td()->Determine(&m)); + + ast::IdentifierExpression ident("my_func"); + EXPECT_TRUE(td()->DetermineResultType(&ident)); + ASSERT_NE(ident.result_type(), nullptr); + EXPECT_TRUE(ident.result_type()->IsF32()); +} + } // namespace } // namespace tint