writer/spirv: Simplify member accesses

Using semantic info.

Change-Id: Iec9a592d9d66930535ead78fab69a6085a57a941
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50302
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-05-07 21:05:54 +00:00 committed by Commit Bot service account
parent 4cd5eea87e
commit 2e1a284cbb
5 changed files with 92 additions and 123 deletions

View File

@ -2535,7 +2535,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
offset = utils::RoundUp(align, offset); offset = utils::RoundUp(align, offset);
auto* sem_member = builder_->create<sem::StructMember>( auto* sem_member = builder_->create<sem::StructMember>(
member, const_cast<sem::Type*>(type), offset, align, size); member, const_cast<sem::Type*>(type),
static_cast<uint32_t>(sem_members.size()), offset, align, size);
builder_->Sem().Add(member, sem_member); builder_->Sem().Add(member, sem_member);
sem_members.emplace_back(sem_member); sem_members.emplace_back(sem_member);

View File

@ -901,6 +901,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>(); auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
ASSERT_NE(sma, nullptr); ASSERT_NE(sma, nullptr);
EXPECT_EQ(sma->Member()->Type(), ty.f32()); EXPECT_EQ(sma->Member()->Type(), ty.f32());
EXPECT_EQ(sma->Member()->Index(), 1u);
EXPECT_EQ(sma->Member()->Declaration()->symbol(), EXPECT_EQ(sma->Member()->Declaration()->symbol(),
Symbols().Get("second_member")); Symbols().Get("second_member"));
} }
@ -925,6 +926,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>(); auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
ASSERT_NE(sma, nullptr); ASSERT_NE(sma, nullptr);
EXPECT_EQ(sma->Member()->Type(), ty.f32()); EXPECT_EQ(sma->Member()->Type(), ty.f32());
EXPECT_EQ(sma->Member()->Index(), 1u);
} }
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) { TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {

View File

@ -56,11 +56,13 @@ std::string Struct::FriendlyName(const SymbolTable& symbols) const {
StructMember::StructMember(ast::StructMember* declaration, StructMember::StructMember(ast::StructMember* declaration,
sem::Type* type, sem::Type* type,
uint32_t index,
uint32_t offset, uint32_t offset,
uint32_t align, uint32_t align,
uint32_t size) uint32_t size)
: declaration_(declaration), : declaration_(declaration),
type_(type), type_(type),
index_(index),
offset_(offset), offset_(offset),
align_(align), align_(align),
size_(size) {} size_(size) {}

View File

@ -166,11 +166,13 @@ class StructMember : public Castable<StructMember, Node> {
/// Constructor /// Constructor
/// @param declaration the AST declaration node /// @param declaration the AST declaration node
/// @param type the type of the member /// @param type the type of the member
/// @param index the index of the member in the structure
/// @param offset the byte offset from the base of the structure /// @param offset the byte offset from the base of the structure
/// @param align the byte alignment of the member /// @param align the byte alignment of the member
/// @param size the byte size of the member /// @param size the byte size of the member
StructMember(ast::StructMember* declaration, StructMember(ast::StructMember* declaration,
sem::Type* type, sem::Type* type,
uint32_t index,
uint32_t offset, uint32_t offset,
uint32_t align, uint32_t align,
uint32_t size); uint32_t size);
@ -184,6 +186,9 @@ class StructMember : public Castable<StructMember, Node> {
/// @returns the type of the member /// @returns the type of the member
sem::Type* Type() const { return type_; } sem::Type* Type() const { return type_; }
/// @returns the member index
uint32_t Index() const { return index_; }
/// @returns byte offset from base of structure /// @returns byte offset from base of structure
uint32_t Offset() const { return offset_; } uint32_t Offset() const { return offset_; }
@ -196,9 +201,10 @@ class StructMember : public Castable<StructMember, Node> {
private: private:
ast::StructMember* const declaration_; ast::StructMember* const declaration_;
sem::Type* const type_; sem::Type* const type_;
uint32_t const offset_; // Byte offset from base of structure uint32_t const index_;
uint32_t const align_; // Byte alignment of the member uint32_t const offset_;
uint32_t const size_; // Byte size of the member uint32_t const align_;
uint32_t const size_;
}; };
} // namespace sem } // namespace sem

View File

@ -26,6 +26,7 @@
#include "src/sem/depth_texture_type.h" #include "src/sem/depth_texture_type.h"
#include "src/sem/function.h" #include "src/sem/function.h"
#include "src/sem/intrinsic.h" #include "src/sem/intrinsic.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/multisampled_texture_type.h" #include "src/sem/multisampled_texture_type.h"
#include "src/sem/sampled_texture_type.h" #include "src/sem/sampled_texture_type.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
@ -89,24 +90,6 @@ bool LastIsTerminator(const ast::BlockStatement* stmts) {
last->Is<ast::FallthroughStatement>(); last->Is<ast::FallthroughStatement>();
} }
uint32_t IndexFromName(char name) {
switch (name) {
case 'x':
case 'r':
return 0;
case 'y':
case 'g':
return 1;
case 'z':
case 'b':
return 2;
case 'w':
case 'a':
return 3;
}
return std::numeric_limits<uint32_t>::max();
}
/// Returns the matrix type that is `type` or that is wrapped by /// Returns the matrix type that is `type` or that is wrapped by
/// one or more levels of an arrays inside of `type`. /// one or more levels of an arrays inside of `type`.
/// @param type the given type, which must not be null /// @param type the given type, which must not be null
@ -880,23 +863,11 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) { AccessorInfo* info) {
auto* data_type = auto* expr_sem = builder_.Sem().Get(expr);
TypeOf(expr->structure())->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* expr_type = expr_sem->Type();
auto* expr_type = TypeOf(expr);
// If the data_type is a structure we're accessing a member, if it's a if (auto* access = expr_sem->As<sem::StructMemberAccess>()) {
// vector we're accessing a swizzle. uint32_t idx = access->Member()->Index();
if (auto* str = data_type->As<sem::Struct>()) {
auto* impl = str->Declaration();
auto symbol = expr->member()->symbol();
uint32_t idx = 0;
for (; idx < impl->members().size(); ++idx) {
auto* member = impl->members()[idx];
if (member->symbol() == symbol) {
break;
}
}
if (info->source_type->Is<sem::Pointer>()) { if (info->source_type->Is<sem::Pointer>()) {
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx)); auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
@ -927,23 +898,12 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
return true; return true;
} }
if (!data_type->Is<sem::Vector>()) { if (auto* swizzle = expr_sem->As<sem::Swizzle>()) {
error_ = "Member accessor without a struct or vector. Something is wrong";
return false;
}
// TODO(dsinclair): Swizzle stuff
auto swiz = builder_.Symbols().NameFor(expr->member()->symbol());
// Single element swizzle is either an access chain or a composite extract // Single element swizzle is either an access chain or a composite extract
if (swiz.size() == 1) { auto& indices = swizzle->Indices();
auto val = IndexFromName(swiz[0]); if (indices.size() == 1) {
if (val == std::numeric_limits<uint32_t>::max()) {
error_ = "invalid swizzle name: " + swiz;
return false;
}
if (info->source_type->Is<sem::Pointer>()) { if (info->source_type->Is<sem::Pointer>()) {
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val)); auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
if (idx_id == 0) { if (idx_id == 0) {
return 0; return 0;
} }
@ -959,7 +919,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
if (!push_function_inst( if (!push_function_inst(
spv::Op::OpCompositeExtract, spv::Op::OpCompositeExtract,
{Operand::Int(result_type_id), extract, {Operand::Int(result_type_id), extract,
Operand::Int(info->source_id), Operand::Int(val)})) { Operand::Int(info->source_id), Operand::Int(indices[0])})) {
return false; return false;
} }
@ -972,8 +932,8 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
// Store the type away as it may change if we run the access chain // Store the type away as it may change if we run the access chain
auto* incoming_type = info->source_type; auto* incoming_type = info->source_type;
// Multi-item extract is a VectorShuffle. We have to emit any existing access // Multi-item extract is a VectorShuffle. We have to emit any existing
// chain data, then load the access chain and shuffle that. // access chain data, then load the access chain and shuffle that.
if (!info->access_chain_indices.empty()) { if (!info->access_chain_indices.empty()) {
auto result_type_id = GenerateTypeIfNeeded(info->source_type); auto result_type_id = GenerateTypeIfNeeded(info->source_type);
if (result_type_id == 0) { if (result_type_id == 0) {
@ -1007,17 +967,11 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
OperandList ops = {Operand::Int(result_type_id), result, Operand::Int(vec_id), OperandList ops = {Operand::Int(result_type_id), result,
Operand::Int(vec_id)}; Operand::Int(vec_id), Operand::Int(vec_id)};
for (uint32_t i = 0; i < swiz.size(); ++i) { for (auto idx : indices) {
auto val = IndexFromName(swiz[i]); ops.push_back(Operand::Int(idx));
if (val == std::numeric_limits<uint32_t>::max()) {
error_ = "invalid swizzle name: " + swiz;
return false;
}
ops.push_back(Operand::Int(val));
} }
if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) { if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
@ -1025,10 +979,14 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
} }
info->source_id = result_id; info->source_id = result_id;
info->source_type = expr_type; info->source_type = expr_type;
return true; return true;
} }
TINT_ICE(builder_.Diagnostics())
<< "unhandled member index type: " << expr_sem->TypeInfo().name;
return false;
}
uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
if (!expr->IsAnyOf<ast::ArrayAccessorExpression, if (!expr->IsAnyOf<ast::ArrayAccessorExpression,
ast::MemberAccessorExpression>()) { ast::MemberAccessorExpression>()) {