Add type determination for member accessor.

This Cl adds the member accessor type determination for both structures
and vector swizzles.

Bug: tint:5
Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-07 16:41:33 +00:00
parent 4e8079544a
commit 8ee1d22882
4 changed files with 195 additions and 58 deletions

View File

@ -48,7 +48,7 @@ enum class Relation {
kModulo, kModulo,
}; };
/// A Relational Expression /// An xor expression
class RelationalExpression : public Expression { class RelationalExpression : public Expression {
public: public:
/// Constructor /// Constructor

View File

@ -28,12 +28,14 @@
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h" #include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h" #include "src/ast/regardless_statement.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h" #include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h" #include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
#include "src/ast/unless_statement.h" #include "src/ast/unless_statement.h"
@ -41,7 +43,10 @@
namespace tint { namespace tint {
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {} TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
// TODO(dsinclair): Temporary usage to avoid compiler warning
static_cast<void>(ctx_.type_mgr());
}
TypeDeterminer::~TypeDeterminer() = default; TypeDeterminer::~TypeDeterminer() = default;
@ -174,15 +179,6 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
return false; return false;
} }
bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
for (const auto& expr : exprs) {
if (!DetermineResultType(expr.get())) {
return false;
}
}
return true;
}
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
// This is blindly called above, so in some cases the expression won't exist. // This is blindly called above, so in some cases the expression won't exist.
if (!expr) { if (!expr) {
@ -207,6 +203,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
if (expr->IsIdentifier()) { if (expr->IsIdentifier()) {
return DetermineIdentifier(expr->AsIdentifier()); return DetermineIdentifier(expr->AsIdentifier());
} }
if (expr->IsMemberAccessor()) {
return DetermineMemberAccessor(expr->AsMemberAccessor());
}
error_ = "unknown expression for type determination"; error_ = "unknown expression for type determination";
return false; return false;
@ -242,9 +241,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
if (!DetermineResultType(expr->func())) { if (!DetermineResultType(expr->func())) {
return false; return false;
} }
if (!DetermineResultType(expr->params())) {
return false;
}
expr->set_result_type(expr->func()->result_type()); expr->set_result_type(expr->func()->result_type());
return true; return true;
} }
@ -283,7 +279,45 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
return true; return true;
} }
error_ = "unknown identifier for type determination"; return true;
}
bool TypeDeterminer::DetermineMemberAccessor(
ast::MemberAccessorExpression* expr) {
if (!DetermineResultType(expr->structure())) {
return false;
}
auto data_type = expr->structure()->result_type();
if (data_type->IsStruct()) {
auto strct = data_type->AsStruct()->impl();
auto name = expr->member()->name()[0];
for (const auto& member : strct->members()) {
if (member->name() != name) {
continue;
}
expr->set_result_type(member->type());
return true;
}
error_ = "struct member not found";
return false;
}
if (data_type->IsVector()) {
auto vec = data_type->AsVector();
// The vector will have a number of components equal to the length of the
// swizzle. This assumes the validator will check that the swizzle
// is correct.
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
vec->type(), expr->member()->name()[0].size())));
return true;
}
error_ = "invalid type in member accessor";
return false; return false;
} }

View File

@ -31,6 +31,7 @@ class CallExpression;
class CastExpression; class CastExpression;
class ConstructorExpression; class ConstructorExpression;
class IdentifierExpression; class IdentifierExpression;
class MemberAccessorExpression;
class Function; class Function;
class Variable; class Variable;
@ -67,10 +68,6 @@ class TypeDeterminer {
/// @param stmt the statement to check /// @param stmt the statement to check
/// @returns true if the determination was successful /// @returns true if the determination was successful
bool DetermineResultType(ast::Statement* stmt); bool DetermineResultType(ast::Statement* stmt);
/// Determines type information for a list of expressions
/// @param exprs the expressions to check
/// @returns true if the determination was successful
bool DetermineResultType(const ast::ExpressionList& exprs);
/// Determines type information for an expression /// Determines type information for an expression
/// @param expr the expression to check /// @param expr the expression to check
/// @returns true if the determination was successful /// @returns true if the determination was successful
@ -83,6 +80,8 @@ class TypeDeterminer {
bool DetermineCast(ast::CastExpression* expr); bool DetermineCast(ast::CastExpression* expr);
bool DetermineConstructor(ast::ConstructorExpression* expr); bool DetermineConstructor(ast::ConstructorExpression* expr);
bool DetermineIdentifier(ast::IdentifierExpression* expr); bool DetermineIdentifier(ast::IdentifierExpression* expr);
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
Context& ctx_; Context& ctx_;
std::string error_; std::string error_;
ScopeStack<ast::Variable*> variable_stack_; ScopeStack<ast::Variable*> variable_stack_;

View File

@ -1,3 +1,4 @@
// Copyright 2020 The Tint Authors. // Copyright 2020 The Tint Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
@ -32,14 +33,18 @@
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/int_literal.h" #include "src/ast/int_literal.h"
#include "src/ast/loop_statement.h" #include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h" #include "src/ast/regardless_statement.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/struct.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h" #include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h" #include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
#include "src/ast/unless_statement.h" #include "src/ast/unless_statement.h"
@ -512,45 +517,6 @@ TEST_F(TypeDeterminerTest, Expr_Call) {
EXPECT_TRUE(call.result_type()->IsF32()); EXPECT_TRUE(call.result_type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::VariableList params;
params.push_back(
std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
params.push_back(
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
auto func =
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
ast::Module m;
m.AddFunction(std::move(func));
// Register the function
EXPECT_TRUE(td()->Determine(&m));
ast::ExpressionList call_params;
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
auto a_ptr = call_params.back().get();
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1)));
auto b_ptr = call_params.back().get();
ast::CallExpression call(
std::make_unique<ast::IdentifierExpression>("my_func"),
std::move(call_params));
EXPECT_TRUE(td()->DetermineResultType(&call));
ASSERT_NE(call.result_type(), nullptr);
EXPECT_TRUE(call.result_type()->IsF32());
ASSERT_NE(a_ptr->result_type(), nullptr);
EXPECT_TRUE(a_ptr->result_type()->IsF32());
ASSERT_NE(b_ptr->result_type(), nullptr);
EXPECT_TRUE(b_ptr->result_type()->IsI32());
}
TEST_F(TypeDeterminerTest, Expr_Cast) { TEST_F(TypeDeterminerTest, Expr_Cast) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::CastExpression cast(&f32, ast::CastExpression cast(&f32,
@ -651,5 +617,143 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
EXPECT_TRUE(ident.result_type()->IsF32()); EXPECT_TRUE(ident.result_type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::StructMemberDecorationList decos;
ast::StructMemberList members;
members.push_back(std::make_unique<ast::StructMember>("first_member", &i32,
std::move(decos)));
members.push_back(std::make_unique<ast::StructMember>("second_member", &f32,
std::move(decos)));
auto strct = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(members));
ast::type::StructType st(std::move(strct));
auto var = std::make_unique<ast::Variable>("my_struct",
ast::StorageClass::kNone, &st);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine(&m));
auto ident = std::make_unique<ast::IdentifierExpression>("my_struct");
auto mem_ident = std::make_unique<ast::IdentifierExpression>("second_member");
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
EXPECT_TRUE(td()->DetermineResultType(&mem));
ASSERT_NE(mem.result_type(), nullptr);
EXPECT_TRUE(mem.result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
&vec3);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine(&m));
auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
auto swizzle = std::make_unique<ast::IdentifierExpression>("xy");
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
}
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
// struct b {
// vec4<f32> foo
// }
// struct A {
// vec3<struct b> mem
// }
// var c : A
// c.mem[0].foo.yx
// -> vec2<f32>
//
// MemberAccessor{
// MemberAccessor{
// ArrayAccessor{
// MemberAccessor{
// Identifier{c}
// Identifier{mem}
// }
// ScalarConstructor{0}
// }
// Identifier{foo}
// }
// Identifier{yx}
// }
//
ast::type::I32Type i32;
ast::type::F32Type f32;
ast::type::VectorType vec4(&f32, 4);
ast::StructMemberDecorationList decos;
ast::StructMemberList b_members;
b_members.push_back(
std::make_unique<ast::StructMember>("foo", &vec4, std::move(decos)));
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(b_members));
ast::type::StructType stB(std::move(strctB));
ast::type::VectorType vecB(&stB, 3);
ast::StructMemberList a_members;
a_members.push_back(
std::make_unique<ast::StructMember>("mem", &vecB, std::move(decos)));
auto strctA = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
std::move(a_members));
ast::type::StructType stA(std::move(strctA));
auto var =
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
EXPECT_TRUE(td()->Determine(&m));
auto ident = std::make_unique<ast::IdentifierExpression>("c");
auto mem_ident = std::make_unique<ast::IdentifierExpression>("mem");
auto foo_ident = std::make_unique<ast::IdentifierExpression>("foo");
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 0));
auto swizzle = std::make_unique<ast::IdentifierExpression>("yx");
ast::MemberAccessorExpression mem(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::ArrayAccessorExpression>(
std::make_unique<ast::MemberAccessorExpression>(
std::move(ident), std::move(mem_ident)),
std::move(idx)),
std::move(foo_ident)),
std::move(swizzle));
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
ASSERT_NE(mem.result_type(), nullptr);
ASSERT_TRUE(mem.result_type()->IsVector());
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
}
} // namespace } // namespace
} // namespace tint } // namespace tint