diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 488b9ad1d3..e51e1f0c48 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -17,6 +17,7 @@ #include #include "src/ast/array_accessor_expression.h" +#include "src/ast/as_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/break_statement.h" #include "src/ast/case_statement.h" @@ -183,6 +184,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { if (expr->IsArrayAccessor()) { return DetermineArrayAccessor(expr->AsArrayAccessor()); } + if (expr->IsAs()) { + return DetermineAs(expr->AsAs()); + } if (expr->IsConstructor()) { return DetermineConstructor(expr->AsConstructor()); } @@ -215,6 +219,11 @@ bool TypeDeterminer::DetermineArrayAccessor( return true; } +bool TypeDeterminer::DetermineAs(ast::AsExpression* expr) { + expr->set_result_type(expr->type()); + return true; +} + bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) { if (expr->IsTypeConstructor()) { expr->set_result_type(expr->AsTypeConstructor()->type()); diff --git a/src/type_determiner.h b/src/type_determiner.h index 770a1bd4f1..844c81e590 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -26,6 +26,7 @@ namespace tint { namespace ast { class ArrayAccessorExpression; +class AsExpression; class ConstructorExpression; class IdentifierExpression; class Function; @@ -71,6 +72,7 @@ class TypeDeterminer { private: bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); + bool DetermineAs(ast::AsExpression* expr); bool DetermineConstructor(ast::ConstructorExpression* expr); bool DetermineIdentifier(ast::IdentifierExpression* expr); Context& ctx_; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 70dcd76086..84f536ec8c 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -19,6 +19,7 @@ #include "gtest/gtest.h" #include "src/ast/array_accessor_expression.h" +#include "src/ast/as_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/break_statement.h" #include "src/ast/case_statement.h" @@ -478,6 +479,16 @@ TEST_F(TypeDeterminerTest, Expr_ArrayAccessor_Vector) { EXPECT_TRUE(acc.result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Expr_As) { + ast::type::F32Type f32; + ast::AsExpression as(&f32, + std::make_unique("name")); + + EXPECT_TRUE(td()->DetermineResultType(&as)); + ASSERT_NE(as.result_type(), nullptr); + EXPECT_TRUE(as.result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) { ast::type::F32Type f32; ast::ScalarConstructorExpression s(