From 8ee1d2288240cabd56b4c3a417638162b4745a54 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Tue, 7 Apr 2020 16:41:33 +0000 Subject: [PATCH] Add type determination for member accessor. This Cl adds the member accessor type determination for both structures and vector swizzles. Bug: tint:5 Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846 Reviewed-by: David Neto --- src/ast/relational_expression.h | 2 +- src/type_determiner.cc | 62 ++++++++--- src/type_determiner.h | 7 +- src/type_determiner_test.cc | 182 +++++++++++++++++++++++++------- 4 files changed, 195 insertions(+), 58 deletions(-) diff --git a/src/ast/relational_expression.h b/src/ast/relational_expression.h index 04b906414a..f5bd944329 100644 --- a/src/ast/relational_expression.h +++ b/src/ast/relational_expression.h @@ -48,7 +48,7 @@ enum class Relation { kModulo, }; -/// A Relational Expression +/// An xor expression class RelationalExpression : public Expression { public: /// Constructor diff --git a/src/type_determiner.cc b/src/type_determiner.cc index f1da09c174..9a81e19f79 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -28,12 +28,14 @@ #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/loop_statement.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/regardless_statement.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/switch_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unless_statement.h" @@ -41,7 +43,10 @@ namespace tint { -TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {} +TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) { + // TODO(dsinclair): Temporary usage to avoid compiler warning + static_cast(ctx_.type_mgr()); +} TypeDeterminer::~TypeDeterminer() = default; @@ -174,15 +179,6 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { return false; } -bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) { - for (const auto& expr : exprs) { - if (!DetermineResultType(expr.get())) { - return false; - } - } - return true; -} - bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { // This is blindly called above, so in some cases the expression won't exist. if (!expr) { @@ -207,6 +203,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { if (expr->IsIdentifier()) { return DetermineIdentifier(expr->AsIdentifier()); } + if (expr->IsMemberAccessor()) { + return DetermineMemberAccessor(expr->AsMemberAccessor()); + } error_ = "unknown expression for type determination"; return false; @@ -242,9 +241,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) { if (!DetermineResultType(expr->func())) { return false; } - if (!DetermineResultType(expr->params())) { - return false; - } expr->set_result_type(expr->func()->result_type()); return true; } @@ -283,7 +279,45 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { return true; } - error_ = "unknown identifier for type determination"; + return true; +} + +bool TypeDeterminer::DetermineMemberAccessor( + ast::MemberAccessorExpression* expr) { + if (!DetermineResultType(expr->structure())) { + return false; + } + + auto data_type = expr->structure()->result_type(); + if (data_type->IsStruct()) { + auto strct = data_type->AsStruct()->impl(); + auto name = expr->member()->name()[0]; + + for (const auto& member : strct->members()) { + if (member->name() != name) { + continue; + } + + expr->set_result_type(member->type()); + return true; + } + + error_ = "struct member not found"; + return false; + } + if (data_type->IsVector()) { + auto vec = data_type->AsVector(); + + // The vector will have a number of components equal to the length of the + // swizzle. This assumes the validator will check that the swizzle + // is correct. + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique( + vec->type(), expr->member()->name()[0].size()))); + return true; + } + + error_ = "invalid type in member accessor"; return false; } diff --git a/src/type_determiner.h b/src/type_determiner.h index 692947ead9..34fcaf35fa 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -31,6 +31,7 @@ class CallExpression; class CastExpression; class ConstructorExpression; class IdentifierExpression; +class MemberAccessorExpression; class Function; class Variable; @@ -67,10 +68,6 @@ class TypeDeterminer { /// @param stmt the statement to check /// @returns true if the determination was successful bool DetermineResultType(ast::Statement* stmt); - /// Determines type information for a list of expressions - /// @param exprs the expressions to check - /// @returns true if the determination was successful - bool DetermineResultType(const ast::ExpressionList& exprs); /// Determines type information for an expression /// @param expr the expression to check /// @returns true if the determination was successful @@ -83,6 +80,8 @@ class TypeDeterminer { bool DetermineCast(ast::CastExpression* expr); bool DetermineConstructor(ast::ConstructorExpression* expr); bool DetermineIdentifier(ast::IdentifierExpression* expr); + bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr); + Context& ctx_; std::string error_; ScopeStack variable_stack_; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 7796ca7846..486246b591 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -1,3 +1,4 @@ + // Copyright 2020 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,14 +33,18 @@ #include "src/ast/if_statement.h" #include "src/ast/int_literal.h" #include "src/ast/loop_statement.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/regardless_statement.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/struct.h" +#include "src/ast/struct_member.h" #include "src/ast/switch_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unless_statement.h" @@ -512,45 +517,6 @@ TEST_F(TypeDeterminerTest, Expr_Call) { EXPECT_TRUE(call.result_type()->IsF32()); } -TEST_F(TypeDeterminerTest, Expr_Call_WithParams) { - ast::type::F32Type f32; - ast::type::I32Type i32; - - ast::VariableList params; - params.push_back( - std::make_unique("a", ast::StorageClass::kNone, &f32)); - params.push_back( - std::make_unique("b", ast::StorageClass::kNone, &i32)); - - auto func = - std::make_unique("my_func", std::move(params), &f32); - ast::Module m; - m.AddFunction(std::move(func)); - - // Register the function - EXPECT_TRUE(td()->Determine(&m)); - - ast::ExpressionList call_params; - call_params.push_back(std::make_unique( - std::make_unique(&f32, 2.5f))); - auto a_ptr = call_params.back().get(); - call_params.push_back(std::make_unique( - std::make_unique(&i32, 1))); - auto b_ptr = call_params.back().get(); - - ast::CallExpression call( - std::make_unique("my_func"), - std::move(call_params)); - EXPECT_TRUE(td()->DetermineResultType(&call)); - ASSERT_NE(call.result_type(), nullptr); - EXPECT_TRUE(call.result_type()->IsF32()); - - ASSERT_NE(a_ptr->result_type(), nullptr); - EXPECT_TRUE(a_ptr->result_type()->IsF32()); - ASSERT_NE(b_ptr->result_type(), nullptr); - EXPECT_TRUE(b_ptr->result_type()->IsI32()); -} - TEST_F(TypeDeterminerTest, Expr_Cast) { ast::type::F32Type f32; ast::CastExpression cast(&f32, @@ -651,5 +617,143 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { EXPECT_TRUE(ident.result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back(std::make_unique("first_member", &i32, + std::move(decos))); + members.push_back(std::make_unique("second_member", &f32, + std::move(decos))); + + auto strct = std::make_unique(ast::StructDecoration::kNone, + std::move(members)); + + ast::type::StructType st(std::move(strct)); + + auto var = std::make_unique("my_struct", + ast::StorageClass::kNone, &st); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine(&m)); + + auto ident = std::make_unique("my_struct"); + auto mem_ident = std::make_unique("second_member"); + + ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident)); + EXPECT_TRUE(td()->DetermineResultType(&mem)); + ASSERT_NE(mem.result_type(), nullptr); + EXPECT_TRUE(mem.result_type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto var = std::make_unique("my_vec", ast::StorageClass::kNone, + &vec3); + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine(&m)); + + auto ident = std::make_unique("my_vec"); + auto swizzle = std::make_unique("xy"); + + ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle)); + EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); + ASSERT_NE(mem.result_type(), nullptr); + ASSERT_TRUE(mem.result_type()->IsVector()); + EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32()); + EXPECT_EQ(mem.result_type()->AsVector()->size(), 2); +} + +TEST_F(TypeDeterminerTest, Expr_MultiLevel) { + // struct b { + // vec4 foo + // } + // struct A { + // vec3 mem + // } + // var c : A + // c.mem[0].foo.yx + // -> vec2 + // + // MemberAccessor{ + // MemberAccessor{ + // ArrayAccessor{ + // MemberAccessor{ + // Identifier{c} + // Identifier{mem} + // } + // ScalarConstructor{0} + // } + // Identifier{foo} + // } + // Identifier{yx} + // } + // + ast::type::I32Type i32; + ast::type::F32Type f32; + + ast::type::VectorType vec4(&f32, 4); + + ast::StructMemberDecorationList decos; + ast::StructMemberList b_members; + b_members.push_back( + std::make_unique("foo", &vec4, std::move(decos))); + + auto strctB = std::make_unique(ast::StructDecoration::kNone, + std::move(b_members)); + ast::type::StructType stB(std::move(strctB)); + + ast::type::VectorType vecB(&stB, 3); + + ast::StructMemberList a_members; + a_members.push_back( + std::make_unique("mem", &vecB, std::move(decos))); + + auto strctA = std::make_unique(ast::StructDecoration::kNone, + std::move(a_members)); + + ast::type::StructType stA(std::move(strctA)); + + auto var = + std::make_unique("c", ast::StorageClass::kNone, &stA); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + // Register the global + EXPECT_TRUE(td()->Determine(&m)); + + auto ident = std::make_unique("c"); + auto mem_ident = std::make_unique("mem"); + auto foo_ident = std::make_unique("foo"); + auto idx = std::make_unique( + std::make_unique(&i32, 0)); + auto swizzle = std::make_unique("yx"); + + ast::MemberAccessorExpression mem( + std::make_unique( + std::make_unique( + std::make_unique( + std::move(ident), std::move(mem_ident)), + std::move(idx)), + std::move(foo_ident)), + std::move(swizzle)); + EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error(); + ASSERT_NE(mem.result_type(), nullptr); + ASSERT_TRUE(mem.result_type()->IsVector()); + EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32()); + EXPECT_EQ(mem.result_type()->AsVector()->size(), 2); +} + } // namespace } // namespace tint