Add type determination for member accessor.
This Cl adds the member accessor type determination for both structures and vector swizzles. Bug: tint:5 Change-Id: I1172db29d8cbed2d9e0ae228ebc3a818d4930b7f Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18846 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
4e8079544a
commit
8ee1d22882
|
@ -48,7 +48,7 @@ enum class Relation {
|
||||||
kModulo,
|
kModulo,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A Relational Expression
|
/// An xor expression
|
||||||
class RelationalExpression : public Expression {
|
class RelationalExpression : public Expression {
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
/// Constructor
|
||||||
|
|
|
@ -28,12 +28,14 @@
|
||||||
#include "src/ast/identifier_expression.h"
|
#include "src/ast/identifier_expression.h"
|
||||||
#include "src/ast/if_statement.h"
|
#include "src/ast/if_statement.h"
|
||||||
#include "src/ast/loop_statement.h"
|
#include "src/ast/loop_statement.h"
|
||||||
|
#include "src/ast/member_accessor_expression.h"
|
||||||
#include "src/ast/regardless_statement.h"
|
#include "src/ast/regardless_statement.h"
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
#include "src/ast/scalar_constructor_expression.h"
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
#include "src/ast/switch_statement.h"
|
#include "src/ast/switch_statement.h"
|
||||||
#include "src/ast/type/array_type.h"
|
#include "src/ast/type/array_type.h"
|
||||||
#include "src/ast/type/matrix_type.h"
|
#include "src/ast/type/matrix_type.h"
|
||||||
|
#include "src/ast/type/struct_type.h"
|
||||||
#include "src/ast/type/vector_type.h"
|
#include "src/ast/type/vector_type.h"
|
||||||
#include "src/ast/type_constructor_expression.h"
|
#include "src/ast/type_constructor_expression.h"
|
||||||
#include "src/ast/unless_statement.h"
|
#include "src/ast/unless_statement.h"
|
||||||
|
@ -41,7 +43,10 @@
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
|
||||||
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {}
|
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
|
||||||
|
// TODO(dsinclair): Temporary usage to avoid compiler warning
|
||||||
|
static_cast<void>(ctx_.type_mgr());
|
||||||
|
}
|
||||||
|
|
||||||
TypeDeterminer::~TypeDeterminer() = default;
|
TypeDeterminer::~TypeDeterminer() = default;
|
||||||
|
|
||||||
|
@ -174,15 +179,6 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
|
|
||||||
for (const auto& expr : exprs) {
|
|
||||||
if (!DetermineResultType(expr.get())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||||
// This is blindly called above, so in some cases the expression won't exist.
|
// This is blindly called above, so in some cases the expression won't exist.
|
||||||
if (!expr) {
|
if (!expr) {
|
||||||
|
@ -207,6 +203,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||||
if (expr->IsIdentifier()) {
|
if (expr->IsIdentifier()) {
|
||||||
return DetermineIdentifier(expr->AsIdentifier());
|
return DetermineIdentifier(expr->AsIdentifier());
|
||||||
}
|
}
|
||||||
|
if (expr->IsMemberAccessor()) {
|
||||||
|
return DetermineMemberAccessor(expr->AsMemberAccessor());
|
||||||
|
}
|
||||||
|
|
||||||
error_ = "unknown expression for type determination";
|
error_ = "unknown expression for type determination";
|
||||||
return false;
|
return false;
|
||||||
|
@ -242,9 +241,6 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
|
||||||
if (!DetermineResultType(expr->func())) {
|
if (!DetermineResultType(expr->func())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!DetermineResultType(expr->params())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
expr->set_result_type(expr->func()->result_type());
|
expr->set_result_type(expr->func()->result_type());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -283,7 +279,45 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
error_ = "unknown identifier for type determination";
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TypeDeterminer::DetermineMemberAccessor(
|
||||||
|
ast::MemberAccessorExpression* expr) {
|
||||||
|
if (!DetermineResultType(expr->structure())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto data_type = expr->structure()->result_type();
|
||||||
|
if (data_type->IsStruct()) {
|
||||||
|
auto strct = data_type->AsStruct()->impl();
|
||||||
|
auto name = expr->member()->name()[0];
|
||||||
|
|
||||||
|
for (const auto& member : strct->members()) {
|
||||||
|
if (member->name() != name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
expr->set_result_type(member->type());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
error_ = "struct member not found";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (data_type->IsVector()) {
|
||||||
|
auto vec = data_type->AsVector();
|
||||||
|
|
||||||
|
// The vector will have a number of components equal to the length of the
|
||||||
|
// swizzle. This assumes the validator will check that the swizzle
|
||||||
|
// is correct.
|
||||||
|
expr->set_result_type(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||||
|
vec->type(), expr->member()->name()[0].size())));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
error_ = "invalid type in member accessor";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ class CallExpression;
|
||||||
class CastExpression;
|
class CastExpression;
|
||||||
class ConstructorExpression;
|
class ConstructorExpression;
|
||||||
class IdentifierExpression;
|
class IdentifierExpression;
|
||||||
|
class MemberAccessorExpression;
|
||||||
class Function;
|
class Function;
|
||||||
class Variable;
|
class Variable;
|
||||||
|
|
||||||
|
@ -67,10 +68,6 @@ class TypeDeterminer {
|
||||||
/// @param stmt the statement to check
|
/// @param stmt the statement to check
|
||||||
/// @returns true if the determination was successful
|
/// @returns true if the determination was successful
|
||||||
bool DetermineResultType(ast::Statement* stmt);
|
bool DetermineResultType(ast::Statement* stmt);
|
||||||
/// Determines type information for a list of expressions
|
|
||||||
/// @param exprs the expressions to check
|
|
||||||
/// @returns true if the determination was successful
|
|
||||||
bool DetermineResultType(const ast::ExpressionList& exprs);
|
|
||||||
/// Determines type information for an expression
|
/// Determines type information for an expression
|
||||||
/// @param expr the expression to check
|
/// @param expr the expression to check
|
||||||
/// @returns true if the determination was successful
|
/// @returns true if the determination was successful
|
||||||
|
@ -83,6 +80,8 @@ class TypeDeterminer {
|
||||||
bool DetermineCast(ast::CastExpression* expr);
|
bool DetermineCast(ast::CastExpression* expr);
|
||||||
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
||||||
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
||||||
|
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
|
||||||
|
|
||||||
Context& ctx_;
|
Context& ctx_;
|
||||||
std::string error_;
|
std::string error_;
|
||||||
ScopeStack<ast::Variable*> variable_stack_;
|
ScopeStack<ast::Variable*> variable_stack_;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
// Copyright 2020 The Tint Authors.
|
// Copyright 2020 The Tint Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -32,14 +33,18 @@
|
||||||
#include "src/ast/if_statement.h"
|
#include "src/ast/if_statement.h"
|
||||||
#include "src/ast/int_literal.h"
|
#include "src/ast/int_literal.h"
|
||||||
#include "src/ast/loop_statement.h"
|
#include "src/ast/loop_statement.h"
|
||||||
|
#include "src/ast/member_accessor_expression.h"
|
||||||
#include "src/ast/regardless_statement.h"
|
#include "src/ast/regardless_statement.h"
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
#include "src/ast/scalar_constructor_expression.h"
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
|
#include "src/ast/struct.h"
|
||||||
|
#include "src/ast/struct_member.h"
|
||||||
#include "src/ast/switch_statement.h"
|
#include "src/ast/switch_statement.h"
|
||||||
#include "src/ast/type/array_type.h"
|
#include "src/ast/type/array_type.h"
|
||||||
#include "src/ast/type/f32_type.h"
|
#include "src/ast/type/f32_type.h"
|
||||||
#include "src/ast/type/i32_type.h"
|
#include "src/ast/type/i32_type.h"
|
||||||
#include "src/ast/type/matrix_type.h"
|
#include "src/ast/type/matrix_type.h"
|
||||||
|
#include "src/ast/type/struct_type.h"
|
||||||
#include "src/ast/type/vector_type.h"
|
#include "src/ast/type/vector_type.h"
|
||||||
#include "src/ast/type_constructor_expression.h"
|
#include "src/ast/type_constructor_expression.h"
|
||||||
#include "src/ast/unless_statement.h"
|
#include "src/ast/unless_statement.h"
|
||||||
|
@ -512,45 +517,6 @@ TEST_F(TypeDeterminerTest, Expr_Call) {
|
||||||
EXPECT_TRUE(call.result_type()->IsF32());
|
EXPECT_TRUE(call.result_type()->IsF32());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
|
|
||||||
ast::type::F32Type f32;
|
|
||||||
ast::type::I32Type i32;
|
|
||||||
|
|
||||||
ast::VariableList params;
|
|
||||||
params.push_back(
|
|
||||||
std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
|
|
||||||
params.push_back(
|
|
||||||
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
|
|
||||||
|
|
||||||
auto func =
|
|
||||||
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
|
|
||||||
ast::Module m;
|
|
||||||
m.AddFunction(std::move(func));
|
|
||||||
|
|
||||||
// Register the function
|
|
||||||
EXPECT_TRUE(td()->Determine(&m));
|
|
||||||
|
|
||||||
ast::ExpressionList call_params;
|
|
||||||
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
|
||||||
std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
|
|
||||||
auto a_ptr = call_params.back().get();
|
|
||||||
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
|
||||||
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
|
||||||
auto b_ptr = call_params.back().get();
|
|
||||||
|
|
||||||
ast::CallExpression call(
|
|
||||||
std::make_unique<ast::IdentifierExpression>("my_func"),
|
|
||||||
std::move(call_params));
|
|
||||||
EXPECT_TRUE(td()->DetermineResultType(&call));
|
|
||||||
ASSERT_NE(call.result_type(), nullptr);
|
|
||||||
EXPECT_TRUE(call.result_type()->IsF32());
|
|
||||||
|
|
||||||
ASSERT_NE(a_ptr->result_type(), nullptr);
|
|
||||||
EXPECT_TRUE(a_ptr->result_type()->IsF32());
|
|
||||||
ASSERT_NE(b_ptr->result_type(), nullptr);
|
|
||||||
EXPECT_TRUE(b_ptr->result_type()->IsI32());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Expr_Cast) {
|
TEST_F(TypeDeterminerTest, Expr_Cast) {
|
||||||
ast::type::F32Type f32;
|
ast::type::F32Type f32;
|
||||||
ast::CastExpression cast(&f32,
|
ast::CastExpression cast(&f32,
|
||||||
|
@ -651,5 +617,143 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
|
||||||
EXPECT_TRUE(ident.result_type()->IsF32());
|
EXPECT_TRUE(ident.result_type()->IsF32());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) {
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
ast::StructMemberDecorationList decos;
|
||||||
|
ast::StructMemberList members;
|
||||||
|
members.push_back(std::make_unique<ast::StructMember>("first_member", &i32,
|
||||||
|
std::move(decos)));
|
||||||
|
members.push_back(std::make_unique<ast::StructMember>("second_member", &f32,
|
||||||
|
std::move(decos)));
|
||||||
|
|
||||||
|
auto strct = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||||
|
std::move(members));
|
||||||
|
|
||||||
|
ast::type::StructType st(std::move(strct));
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_struct",
|
||||||
|
ast::StorageClass::kNone, &st);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
// Register the global
|
||||||
|
EXPECT_TRUE(td()->Determine(&m));
|
||||||
|
|
||||||
|
auto ident = std::make_unique<ast::IdentifierExpression>("my_struct");
|
||||||
|
auto mem_ident = std::make_unique<ast::IdentifierExpression>("second_member");
|
||||||
|
|
||||||
|
ast::MemberAccessorExpression mem(std::move(ident), std::move(mem_ident));
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&mem));
|
||||||
|
ASSERT_NE(mem.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(mem.result_type()->IsF32());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_MemberAccessor_VectorSwizzle) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_vec", ast::StorageClass::kNone,
|
||||||
|
&vec3);
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
// Register the global
|
||||||
|
EXPECT_TRUE(td()->Determine(&m));
|
||||||
|
|
||||||
|
auto ident = std::make_unique<ast::IdentifierExpression>("my_vec");
|
||||||
|
auto swizzle = std::make_unique<ast::IdentifierExpression>("xy");
|
||||||
|
|
||||||
|
ast::MemberAccessorExpression mem(std::move(ident), std::move(swizzle));
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||||
|
ASSERT_NE(mem.result_type(), nullptr);
|
||||||
|
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||||
|
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||||
|
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
||||||
|
// struct b {
|
||||||
|
// vec4<f32> foo
|
||||||
|
// }
|
||||||
|
// struct A {
|
||||||
|
// vec3<struct b> mem
|
||||||
|
// }
|
||||||
|
// var c : A
|
||||||
|
// c.mem[0].foo.yx
|
||||||
|
// -> vec2<f32>
|
||||||
|
//
|
||||||
|
// MemberAccessor{
|
||||||
|
// MemberAccessor{
|
||||||
|
// ArrayAccessor{
|
||||||
|
// MemberAccessor{
|
||||||
|
// Identifier{c}
|
||||||
|
// Identifier{mem}
|
||||||
|
// }
|
||||||
|
// ScalarConstructor{0}
|
||||||
|
// }
|
||||||
|
// Identifier{foo}
|
||||||
|
// }
|
||||||
|
// Identifier{yx}
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
ast::type::VectorType vec4(&f32, 4);
|
||||||
|
|
||||||
|
ast::StructMemberDecorationList decos;
|
||||||
|
ast::StructMemberList b_members;
|
||||||
|
b_members.push_back(
|
||||||
|
std::make_unique<ast::StructMember>("foo", &vec4, std::move(decos)));
|
||||||
|
|
||||||
|
auto strctB = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||||
|
std::move(b_members));
|
||||||
|
ast::type::StructType stB(std::move(strctB));
|
||||||
|
|
||||||
|
ast::type::VectorType vecB(&stB, 3);
|
||||||
|
|
||||||
|
ast::StructMemberList a_members;
|
||||||
|
a_members.push_back(
|
||||||
|
std::make_unique<ast::StructMember>("mem", &vecB, std::move(decos)));
|
||||||
|
|
||||||
|
auto strctA = std::make_unique<ast::Struct>(ast::StructDecoration::kNone,
|
||||||
|
std::move(a_members));
|
||||||
|
|
||||||
|
ast::type::StructType stA(std::move(strctA));
|
||||||
|
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("c", ast::StorageClass::kNone, &stA);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
// Register the global
|
||||||
|
EXPECT_TRUE(td()->Determine(&m));
|
||||||
|
|
||||||
|
auto ident = std::make_unique<ast::IdentifierExpression>("c");
|
||||||
|
auto mem_ident = std::make_unique<ast::IdentifierExpression>("mem");
|
||||||
|
auto foo_ident = std::make_unique<ast::IdentifierExpression>("foo");
|
||||||
|
auto idx = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 0));
|
||||||
|
auto swizzle = std::make_unique<ast::IdentifierExpression>("yx");
|
||||||
|
|
||||||
|
ast::MemberAccessorExpression mem(
|
||||||
|
std::make_unique<ast::MemberAccessorExpression>(
|
||||||
|
std::make_unique<ast::ArrayAccessorExpression>(
|
||||||
|
std::make_unique<ast::MemberAccessorExpression>(
|
||||||
|
std::move(ident), std::move(mem_ident)),
|
||||||
|
std::move(idx)),
|
||||||
|
std::move(foo_ident)),
|
||||||
|
std::move(swizzle));
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&mem)) << td()->error();
|
||||||
|
ASSERT_NE(mem.result_type(), nullptr);
|
||||||
|
ASSERT_TRUE(mem.result_type()->IsVector());
|
||||||
|
EXPECT_TRUE(mem.result_type()->AsVector()->type()->IsF32());
|
||||||
|
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
Loading…
Reference in New Issue