semantic: Add Function::Parameters(), MemberAccessorExpression subtypes

Add semantic::Swizzle and semantic::StructMemberAccess, both deriving from MemberAccessorExpression

Add semantic::Function::Parameters() to list the semantic::Variable parameters for the function.

Change-Id: I8cc69f3738380c14f61d051ee2989be6194d148d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47220
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-04-09 13:56:08 +00:00 committed by Commit Bot service account
parent f8f31a458f
commit e9c4984489
11 changed files with 174 additions and 48 deletions

View File

@ -534,7 +534,9 @@ bool Resolver::Function(ast::Function* func) {
variable_stack_.push_scope();
for (auto* param : func->params()) {
variable_stack_.set(param->symbol(), CreateVariableInfo(param));
auto* param_info = CreateVariableInfo(param);
variable_stack_.set(param->symbol(), param_info);
func_info->parameters.emplace_back(param_info);
if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
param->declared_type(),
@ -1171,12 +1173,14 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
std::vector<uint32_t> swizzle;
if (auto* ty = data_type->As<type::Struct>()) {
auto* strct = ty->impl();
auto* str = Structure(ty);
auto symbol = expr->member()->symbol();
for (auto* member : strct->members()) {
if (member->symbol() == symbol) {
ret = member->type();
const semantic::StructMember* member = nullptr;
for (auto* m : str->members) {
if (m->Declaration()->symbol() == symbol) {
ret = m->Declaration()->type();
member = m;
break;
}
}
@ -1192,6 +1196,9 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
if (auto* ptr = res->As<type::Pointer>()) {
ret = builder_->create<type::Pointer>(ret, ptr->storage_class());
}
builder_->Sem().Add(expr, builder_->create<semantic::StructMemberAccess>(
expr, ret, current_statement_, member));
} else if (auto* vec = data_type->As<type::Vector>()) {
std::string str = builder_->Symbols().NameFor(expr->member()->symbol());
auto size = str.size();
@ -1257,6 +1264,9 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
ret = builder_->create<type::Vector>(vec->type(),
static_cast<uint32_t>(size));
}
builder_->Sem().Add(
expr, builder_->create<semantic::Swizzle>(expr, ret, current_statement_,
std::move(swizzle)));
} else {
diagnostics_.add_error(
"invalid use of member accessor on a non-vector/non-struct " +
@ -1265,9 +1275,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
return false;
}
builder_->Sem().Add(expr,
builder_->create<semantic::MemberAccessorExpression>(
expr, ret, current_statement_, std::move(swizzle)));
SetType(expr, ret);
return true;
@ -1682,7 +1689,8 @@ void Resolver::CreateSemanticNodes() const {
auto* info = it.second;
auto* sem_func = builder_->create<semantic::Function>(
info->declaration, remap_vars(info->referenced_module_vars),
info->declaration, remap_vars(info->parameters),
remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars), info->return_statements,
ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func);

View File

@ -106,6 +106,7 @@ class Resolver {
~FunctionInfo();
ast::Function* const declaration;
std::vector<VariableInfo*> parameters;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
std::vector<const ast::ReturnStatement*> return_statements;

View File

@ -736,6 +736,32 @@ TEST_F(ResolverTest, Expr_Identifier_Unknown) {
EXPECT_FALSE(r()->Resolve());
}
TEST_F(ResolverTest, Function_Parameters) {
auto* param_a = Param("a", ty.f32());
auto* param_b = Param("b", ty.i32());
auto* param_c = Param("c", ty.u32());
auto* func = Func("my_func",
ast::VariableList{
param_a,
param_b,
param_c,
},
ty.void_(), {});
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->Parameters().size(), 3u);
EXPECT_EQ(func_sem->Parameters()[0]->Type(), ty.f32());
EXPECT_EQ(func_sem->Parameters()[1]->Type(), ty.i32());
EXPECT_EQ(func_sem->Parameters()[2]->Type(), ty.u32());
EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
}
TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
auto* in_var = Global("in_var", ty.f32(), ast::StorageClass::kInput);
auto* out_var = Global("out_var", ty.f32(), ast::StorageClass::kOutput);
@ -757,6 +783,7 @@ TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->Parameters().size(), 0u);
const auto& vars = func_sem->ReferencedModuleVariables();
ASSERT_EQ(vars.size(), 5u);
@ -794,6 +821,7 @@ TEST_F(ResolverTest, Function_RegisterInputOutputVariables_SubFunction) {
auto* func2_sem = Sem().Get(func2);
ASSERT_NE(func2_sem, nullptr);
EXPECT_EQ(func2_sem->Parameters().size(), 0u);
const auto& vars = func2_sem->ReferencedModuleVariables();
ASSERT_EQ(vars.size(), 5u);
@ -842,6 +870,7 @@ TEST_F(ResolverTest, Function_ReturnStatements) {
auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->Parameters().size(), 0u);
EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
@ -867,6 +896,14 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
EXPECT_EQ(Sem()
.Get(mem)
->As<semantic::StructMemberAccess>()
->Member()
->Declaration()
->symbol(),
Symbols().Get("second_member"));
}
TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
@ -889,6 +926,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
auto* ptr = TypeOf(mem)->As<type::Pointer>();
EXPECT_TRUE(ptr->type()->Is<type::F32>());
ASSERT_TRUE(Sem().Get(mem)->Is<semantic::StructMemberAccess>());
}
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
@ -903,7 +941,9 @@ TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 4u);
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(0, 2, 1, 3));
ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
ElementsAre(0, 2, 1, 3));
}
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
@ -919,7 +959,9 @@ TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
auto* ptr = TypeOf(mem)->As<type::Pointer>();
ASSERT_TRUE(ptr->type()->Is<type::F32>());
EXPECT_THAT(Sem().Get(mem)->Swizzle(), ElementsAre(2));
ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
EXPECT_THAT(Sem().Get(mem)->As<semantic::Swizzle>()->Indices(),
ElementsAre(2));
}
TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
@ -971,6 +1013,7 @@ TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
ASSERT_TRUE(TypeOf(mem)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(mem)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(mem)->As<type::Vector>()->size(), 2u);
ASSERT_TRUE(Sem().Get(mem)->Is<semantic::Swizzle>());
}
TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
@ -1502,6 +1545,10 @@ TEST_F(ResolverTest, Function_EntryPoints_StageDecoration) {
ASSERT_NE(ep_1_sem, nullptr);
ASSERT_NE(ep_2_sem, nullptr);
EXPECT_EQ(func_b_sem->Parameters().size(), 0u);
EXPECT_EQ(func_a_sem->Parameters().size(), 0u);
EXPECT_EQ(func_c_sem->Parameters().size(), 0u);
const auto& b_eps = func_b_sem->AncestorEntryPoints();
ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ(Symbols().Register("ep_1"), b_eps[0]);

View File

@ -46,12 +46,14 @@ class Function : public Castable<Function, CallTarget> {
/// Constructor
/// @param declaration the ast::Function
/// @param parameters the parameters to the function
/// @param referenced_module_vars the referenced module variables
/// @param local_referenced_module_vars the locally referenced module
/// @param return_statements the function return statements
/// variables
/// @param ancestor_entry_points the ancestor entry points
Function(ast::Function* declaration,
std::vector<const Variable*> parameters,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
@ -63,6 +65,9 @@ class Function : public Castable<Function, CallTarget> {
/// @returns the ast::Function declaration
ast::Function* Declaration() const { return declaration_; }
/// @return the parameters to the function
const std::vector<const Variable*> Parameters() const { return parameters_; }
/// Note: If this function calls other functions, the return will also include
/// all of the referenced variables from the callees.
/// @returns the referenced module variables
@ -147,6 +152,7 @@ class Function : public Castable<Function, CallTarget> {
bool multisampled) const;
ast::Function* const declaration_;
std::vector<const Variable*> const parameters_;
std::vector<const Variable*> const referenced_module_vars_;
std::vector<const Variable*> const local_referenced_module_vars_;
std::vector<const ast::ReturnStatement*> const return_statements_;

View File

@ -20,8 +20,18 @@
#include "src/semantic/expression.h"
namespace tint {
/// Forward declarations
namespace ast {
class MemberAccessorExpression;
} // namespace ast
namespace semantic {
/// Forward declarations
class Struct;
class StructMember;
/// MemberAccessorExpression holds the semantic information for a
/// ast::MemberAccessorExpression node.
class MemberAccessorExpression
@ -31,24 +41,60 @@ class MemberAccessorExpression
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
/// @param swizzle if this member access is for a vector swizzle, the swizzle
/// indices
MemberAccessorExpression(ast::Expression* declaration,
MemberAccessorExpression(ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
std::vector<uint32_t> swizzle);
Statement* statement);
/// Destructor
~MemberAccessorExpression() override;
};
/// @return true if this member access is for a vector swizzle
bool IsSwizzle() const { return !swizzle_.empty(); }
/// StructMemberAccess holds the semantic information for a
/// ast::MemberAccessorExpression node that represents an access to a structure
/// member.
class StructMemberAccess
: public Castable<StructMemberAccess, MemberAccessorExpression> {
public:
/// Constructor
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param member the structure member
StructMemberAccess(ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
const StructMember* member);
/// @return the swizzle indices, if this is a vector swizzle
const std::vector<uint32_t>& Swizzle() const { return swizzle_; }
/// Destructor
~StructMemberAccess() override;
/// @returns the structure member
StructMember const* Member() const { return member_; }
private:
std::vector<uint32_t> const swizzle_;
StructMember const* const member_;
};
/// Swizzle holds the semantic information for a ast::MemberAccessorExpression
/// node that represents a vector swizzle.
class Swizzle : public Castable<Swizzle, MemberAccessorExpression> {
public:
/// Constructor
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param indices the swizzle indices
Swizzle(ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
std::vector<uint32_t> indices);
/// Destructor
~Swizzle() override;
/// @return the swizzle indices, if this is a vector swizzle
const std::vector<uint32_t>& Indices() const { return indices_; }
private:
std::vector<uint32_t> const indices_;
};
} // namespace semantic

View File

@ -41,12 +41,14 @@ ParameterList GetParameters(ast::Function* ast) {
} // namespace
Function::Function(ast::Function* declaration,
std::vector<const Variable*> parameters,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points)
: Base(declaration->return_type(), GetParameters(declaration)),
declaration_(declaration),
parameters_(std::move(parameters)),
referenced_module_vars_(std::move(referenced_module_vars)),
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
return_statements_(std::move(return_statements)),

View File

@ -12,21 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/member_accessor_expression.h"
#include "src/semantic/member_accessor_expression.h"
TINT_INSTANTIATE_TYPEINFO(tint::semantic::MemberAccessorExpression);
TINT_INSTANTIATE_TYPEINFO(tint::semantic::StructMemberAccess);
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Swizzle);
namespace tint {
namespace semantic {
MemberAccessorExpression::MemberAccessorExpression(
ast::Expression* declaration,
ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
std::vector<uint32_t> swizzle)
: Base(declaration, type, statement), swizzle_(std::move(swizzle)) {}
Statement* statement)
: Base(declaration, type, statement) {}
MemberAccessorExpression::~MemberAccessorExpression() = default;
StructMemberAccess::StructMemberAccess(
ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
const StructMember* member)
: Base(declaration, type, statement), member_(member) {}
StructMemberAccess::~StructMemberAccess() = default;
Swizzle::Swizzle(ast::MemberAccessorExpression* declaration,
type::Type* type,
Statement* statement,
std::vector<uint32_t> indices)
: Base(declaration, type, statement), indices_(std::move(indices)) {}
Swizzle::~Swizzle() = default;
} // namespace semantic
} // namespace tint

View File

@ -644,27 +644,12 @@ Transform::Output DecomposeStorageAccess::Run(const Program* in,
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
// X.Y
auto* accessor_sem = sem.Get(accessor);
auto swizzle = accessor_sem->Swizzle();
switch (swizzle.size()) {
case 0: {
if (auto access = state.TakeAccess(accessor->structure())) {
auto* str_ty = access.type->As<type::Struct>();
auto* member =
sem.Get(str_ty)->FindMember(accessor->member()->symbol());
auto offset = member->Offset();
state.AddAccesss(
accessor, {
access.var,
Add(std::move(access.offset), std::move(offset)),
member->Declaration()->type()->UnwrapAll(),
});
}
break;
}
case 1: {
if (auto* swizzle = accessor_sem->As<semantic::Swizzle>()) {
if (swizzle->Indices().size() == 1) {
if (auto access = state.TakeAccess(accessor->structure())) {
auto* vec_ty = access.type->As<type::Vector>();
auto offset = Mul(ScalarSize(vec_ty->type()), swizzle[0]);
auto offset =
Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
state.AddAccesss(
accessor, {
access.var,
@ -672,7 +657,19 @@ Transform::Output DecomposeStorageAccess::Run(const Program* in,
vec_ty->type()->UnwrapAll(),
});
}
break;
}
} else {
if (auto access = state.TakeAccess(accessor->structure())) {
auto* str_ty = access.type->As<type::Struct>();
auto* member =
sem.Get(str_ty)->FindMember(accessor->member()->symbol());
auto offset = member->Offset();
state.AddAccesss(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
member->Declaration()->type()->UnwrapAll(),
});
}
}
continue;

View File

@ -53,7 +53,7 @@ Transform::Output Renamer::Run(const Program* in, const DataMap&) {
<< "MemberAccessorExpression has no semantic info";
continue;
}
if (sem->IsSwizzle()) {
if (sem->Is<semantic::Swizzle>()) {
preserve.emplace(member->member());
}
} else if (auto* call = node->As<ast::CallExpression>()) {

View File

@ -2246,7 +2246,7 @@ bool GeneratorImpl::EmitMemberAccessor(std::ostream& pre,
out << ".";
// Swizzles output the name directly
if (builder_.Sem().Get(expr)->IsSwizzle()) {
if (builder_.Sem().Get(expr)->Is<semantic::Swizzle>()) {
out << builder_.Symbols().NameFor(expr->member()->symbol());
} else if (!EmitExpression(pre, out, expr->member())) {
return false;

View File

@ -1737,7 +1737,7 @@ bool GeneratorImpl::EmitMemberAccessor(ast::MemberAccessorExpression* expr) {
out_ << ".";
// Swizzles get written out directly
if (program_->Sem().Get(expr)->IsSwizzle()) {
if (program_->Sem().Get(expr)->Is<semantic::Swizzle>()) {
out_ << program_->Symbols().NameFor(expr->member()->symbol());
} else if (!EmitExpression(expr->member())) {
return false;