Add type determination for member accessor.
This Cl adds the member accessor type determination for both structures and vector swizzles. Bug: tint:5 Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
4e8079544a
commit
8ee1d22882
|
@ -48,7 +48,7 @@ enum class Relation {
|
|||
kModulo,
|
||||
};
|
||||
|
||||
/// A Relational Expression
|
||||
/// An xor expression
|
||||
class RelationalExpression : public Expression {
|
||||
public:
|
||||
/// Constructor
|
||||
|
|
|
@ -28,12 +28,14 @@
|
|||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/if_statement.h"
|
||||
#include "src/ast/loop_statement.h"
|
||||
#include "src/ast/member_accessor_expression.h"
|
||||
#include "src/ast/regardless_statement.h"
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/switch_statement.h"
|
||||
#include "src/ast/type/array_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/struct_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
#include "src/ast/unless_statement.h"
|
||||
|
@ -41,7 +43,10 @@
|
|||
|
||||
namespace tint {
|
||||
|
||||
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {}
|
||||
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
|
||||
// TODO(dsinclair): Temporary usage to avoid compiler warning
|
||||
static_cast<void>(ctx_.type_mgr());
|
||||
}
|
||||
|
||||
TypeDeterminer::~TypeDeterminer() = default;
|
||||
|
||||
|
@ -174,15 +179,6 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
|
||||
for (const auto& expr : exprs) {
|
||||
if (!DetermineResultType(expr.get())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||
// This is blindly called above, so in some cases the expression won't exist.
|
||||
if (!expr) {
|
||||
|
@ -207,6 +203,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
|||
if (expr->IsIdentifier()) {
|
||||
return DetermineIdentifier(expr->AsIdentifier());
|
||||
}
|
||||
if (expr->IsMemberAccessor()) {
|
||||
return DetermineMemberAccessor(expr->AsMemberAccessor());
|
||||
}
|
||||
|
||||
error_ = "unknown expression for type determination";
|
||||
return false;
|
||||
|
@ -242,9 +241,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
|
|||
if (!DetermineResultType(expr->func())) {
|
||||
return false;
|
||||
}
|
||||
if (!DetermineResultType(expr->params())) {
|
||||
return false;
|
||||
}
|
||||
expr->set_result_type(expr->func()->result_type());
|
||||
return true;
|
||||
}
|
||||
|
@ -283,7 +279,45 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
|
|||
return true;
|
||||
}
|
||||
|
||||
error_ = "unknown identifier for type determination";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineMemberAccessor(
|
||||
ast::MemberAccessorExpression* expr) {
|
||||
if (!DetermineResultType(expr->structure())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto data_type = expr->structure()->result_type();
|
||||
if (data_type->IsStruct()) {
|
||||
auto strct = data_type->AsStruct()->impl();
|
||||
auto name = expr->member()->name()[0];
|
||||
|
||||
for (const auto& member : strct->members()) {
|
||||
if (member->name() != name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
expr->set_result_type(member->type());
|
||||
return true;
|
||||
}
|
||||
|
||||
error_ = "struct member not found";
|
||||
return false;
|
||||
}
|
||||
if (data_type->IsVector()) {
|
||||
auto vec = data_type->AsVector();
|
||||
|
||||
// The vector will have a number of components equal to the length of the
|
||||
// swizzle. This assumes the validator will check that the swizzle
|
||||
// is correct.
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
vec->type(), expr->member()->name()[0].size())));
|
||||
return true;
|
||||
}
|
||||
|
||||
error_ = "invalid type in member accessor";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ class CallExpression;
|
|||
class CastExpression;
|
||||
class ConstructorExpression;
|
||||
class IdentifierExpression;
|
||||
class MemberAccessorExpression;
|
||||
class Function;
|
||||
class Variable;
|
||||
|
||||
|
@ -67,10 +68,6 @@ class TypeDeterminer {
|
|||
/// @param stmt the statement to check
|
||||
/// @returns true if the determination was successful
|
||||
bool DetermineResultType(ast::Statement* stmt);
|
||||
/// Determines type information for a list of expressions
|
||||
/// @param exprs the expressions to check
|
||||
/// @returns true if the determination was successful
|
||||
bool DetermineResultType(const ast::ExpressionList& exprs);
|
||||
/// Determines type information for an expression
|
||||
/// @param expr the expression to check
|
||||
/// @returns true if the determination was successful
|
||||
|
@ -83,6 +80,8 @@ class TypeDeterminer {
|
|||
bool DetermineCast(ast::CastExpression* expr);
|
||||
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
||||
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
||||
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
|
||||
|
||||
Context& ctx_;
|
||||
std::string error_;
|
||||
ScopeStack<ast::Variable*> variable_stack_;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -32,14 +33,18 @@
|
|||
#include "src/ast/if_statement.h"
|
||||
#include "src/ast/int_literal.h"
|
||||
#include "src/ast/loop_statement.h"
|
||||
#include "src/ast/member_accessor_expression.h"
|
||||
#include "src/ast/regardless_statement.h"
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/struct.h"
|
||||
#include "src/ast/struct_member.h"
|
||||
#include "src/ast/switch_statement.h"
|
||||
#include "src/ast/type/array_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/struct_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
#include "src/ast/unless_statement.h"
|
||||
|
@ -512,45 +517,6 @@ TEST_F(TypeDeterminerTest, Expr_Call) {
|
|||
EXPECT_TRUE(call.result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::I32Type i32;
|
||||
|
||||
ast::VariableList params;
|
||||
params.push_back(
|
||||
std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
|
||||
params.push_back(
|
||||
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
|
||||
|
||||
auto func =
|
||||
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
|
||||
ast::Module m;
|
||||
m.AddFunction(std::move(func));
|
||||
|
||||
// Register the function
|
||||
EXPECT_TRUE(td()->Determine(&m));
|
||||
|
||||
ast::ExpressionList call_params;
|
||||
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
|
||||
auto a_ptr = call_params.back().get();
|
||||
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||
auto b_ptr = call_params.back().get();
|
||||
|
||||
ast::CallExpression call(
|
||||
std::make_unique<ast::IdentifierExpression>("my_func"),
|
||||
std::move(call_params));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&call));
|
||||
ASSERT_NE(call.result_type(), nullptr);
|
||||
EXPECT_TRUE(call.result_type()->IsF32());
|
||||
|
||||
ASSERT_NE(a_ptr->result_type(), nullptr);
|
||||
EXPECT_TRUE(a_ptr->result_type()->IsF32());
|
||||
ASSERT_NE(b_ptr->result_type(), nullptr);
|
||||
EXPECT_TRUE(b_ptr->result_type()->IsI32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Cast) {
|
||||
ast::type::F32Type f32;
|
||||
ast::CastExpression cast(&f32,
|
||||
|
@ -651,5 +617,143 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
|
|||
EXPECT_TRUE(ident.result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
|
||||
ast::type::I32Type i32;
|
||||
ast::type::F32Type f32;
|
||||
|
||||
ast::StructMemberDecorationList decos;
|
||||
ast::StructMemberList members;
|
||||
members.push_back(std::make_unique<ast::StructMember>("first_member", &i32,
|
||||
std::move(decos)));
|
||||
members.push_back(std::make_unique<ast::StructMember>("second_member", &f32,
|
||||
std::move(decos)));
|
||||
|
||||
auto strct = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||
std::move(members));
|
||||
|
||||
ast::type::StructType st(std::move(strct));
|
||||
|
||||
auto var = std::make_unique<ast::Variable>("my_struct",
|
||||
ast::StorageClass::kNone, &st);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine(&m));
|
||||
|
||||
auto ident = std::make_unique<ast::IdentifierExpression>("my_struct");
|
||||
auto mem_ident = std::make_unique<ast::IdentifierExpression>("second_member");
|
||||
|
||||
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem));
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
EXPECT_TRUE(mem.result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
|
||||
&vec3);
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine(&m));
|
||||
|
||||
auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
|
||||
auto swizzle = std::make_unique<ast::IdentifierExpression>("xy");
|
||||
|
||||
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
||||
// struct b {
|
||||
// vec4<f32> foo
|
||||
// }
|
||||
// struct A {
|
||||
// vec3<struct b> mem
|
||||
// }
|
||||
// var c : A
|
||||
// c.mem[0].foo.yx
|
||||
// -> vec2<f32>
|
||||
//
|
||||
// MemberAccessor{
|
||||
// MemberAccessor{
|
||||
// ArrayAccessor{
|
||||
// MemberAccessor{
|
||||
// Identifier{c}
|
||||
// Identifier{mem}
|
||||
// }
|
||||
// ScalarConstructor{0}
|
||||
// }
|
||||
// Identifier{foo}
|
||||
// }
|
||||
// Identifier{yx}
|
||||
// }
|
||||
//
|
||||
ast::type::I32Type i32;
|
||||
ast::type::F32Type f32;
|
||||
|
||||
ast::type::VectorType vec4(&f32, 4);
|
||||
|
||||
ast::StructMemberDecorationList decos;
|
||||
ast::StructMemberList b_members;
|
||||
b_members.push_back(
|
||||
std::make_unique<ast::StructMember>("foo", &vec4, std::move(decos)));
|
||||
|
||||
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||
std::move(b_members));
|
||||
ast::type::StructType stB(std::move(strctB));
|
||||
|
||||
ast::type::VectorType vecB(&stB, 3);
|
||||
|
||||
ast::StructMemberList a_members;
|
||||
a_members.push_back(
|
||||
std::make_unique<ast::StructMember>("mem", &vecB, std::move(decos)));
|
||||
|
||||
auto strctA = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||
std::move(a_members));
|
||||
|
||||
ast::type::StructType stA(std::move(strctA));
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
EXPECT_TRUE(td()->Determine(&m));
|
||||
|
||||
auto ident = std::make_unique<ast::IdentifierExpression>("c");
|
||||
auto mem_ident = std::make_unique<ast::IdentifierExpression>("mem");
|
||||
auto foo_ident = std::make_unique<ast::IdentifierExpression>("foo");
|
||||
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 0));
|
||||
auto swizzle = std::make_unique<ast::IdentifierExpression>("yx");
|
||||
|
||||
ast::MemberAccessorExpression mem(
|
||||
std::make_unique<ast::MemberAccessorExpression>(
|
||||
std::make_unique<ast::ArrayAccessorExpression>(
|
||||
std::make_unique<ast::MemberAccessorExpression>(
|
||||
std::move(ident), std::move(mem_ident)),
|
||||
std::move(idx)),
|
||||
std::move(foo_ident)),
|
||||
std::move(swizzle));
|
||||
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||
ASSERT_NE(mem.result_type(), nullptr);
|
||||
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint
|
||||
|
|
Loading…
Reference in New Issue