diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index a3162c2fd3..78f91fa701 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2535,7 +2535,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { offset = utils::RoundUp(align, offset); auto* sem_member = builder_->create( - member, const_cast(type), offset, align, size); + member, const_cast(type), + static_cast(sem_members.size()), offset, align, size); builder_->Sem().Add(member, sem_member); sem_members.emplace_back(sem_member); diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index 32a900aa8e..88cfbb6d8a 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -901,6 +901,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { auto* sma = Sem().Get(mem)->As(); ASSERT_NE(sma, nullptr); EXPECT_EQ(sma->Member()->Type(), ty.f32()); + EXPECT_EQ(sma->Member()->Index(), 1u); EXPECT_EQ(sma->Member()->Declaration()->symbol(), Symbols().Get("second_member")); } @@ -925,6 +926,7 @@ TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) { auto* sma = Sem().Get(mem)->As(); ASSERT_NE(sma, nullptr); EXPECT_EQ(sma->Member()->Type(), ty.f32()); + EXPECT_EQ(sma->Member()->Index(), 1u); } TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) { diff --git a/src/sem/struct.cc b/src/sem/struct.cc index 40f3e0649e..f8eefc1474 100644 --- a/src/sem/struct.cc +++ b/src/sem/struct.cc @@ -56,11 +56,13 @@ std::string Struct::FriendlyName(const SymbolTable& symbols) const { StructMember::StructMember(ast::StructMember* declaration, sem::Type* type, + uint32_t index, uint32_t offset, uint32_t align, uint32_t size) : declaration_(declaration), type_(type), + index_(index), offset_(offset), align_(align), size_(size) {} diff --git a/src/sem/struct.h b/src/sem/struct.h index fda7b4f325..6694d6d894 100644 --- a/src/sem/struct.h +++ b/src/sem/struct.h @@ -166,11 +166,13 @@ class StructMember : public Castable { /// Constructor /// @param declaration the AST declaration node /// @param type the type of the member + /// @param index the index of the member in the structure /// @param offset the byte offset from the base of the structure /// @param align the byte alignment of the member /// @param size the byte size of the member StructMember(ast::StructMember* declaration, sem::Type* type, + uint32_t index, uint32_t offset, uint32_t align, uint32_t size); @@ -184,6 +186,9 @@ class StructMember : public Castable { /// @returns the type of the member sem::Type* Type() const { return type_; } + /// @returns the member index + uint32_t Index() const { return index_; } + /// @returns byte offset from base of structure uint32_t Offset() const { return offset_; } @@ -196,9 +201,10 @@ class StructMember : public Castable { private: ast::StructMember* const declaration_; sem::Type* const type_; - uint32_t const offset_; // Byte offset from base of structure - uint32_t const align_; // Byte alignment of the member - uint32_t const size_; // Byte size of the member + uint32_t const index_; + uint32_t const offset_; + uint32_t const align_; + uint32_t const size_; }; } // namespace sem diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index dee0885a70..935387dcab 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -26,6 +26,7 @@ #include "src/sem/depth_texture_type.h" #include "src/sem/function.h" #include "src/sem/intrinsic.h" +#include "src/sem/member_accessor_expression.h" #include "src/sem/multisampled_texture_type.h" #include "src/sem/sampled_texture_type.h" #include "src/sem/struct.h" @@ -89,24 +90,6 @@ bool LastIsTerminator(const ast::BlockStatement* stmts) { last->Is(); } -uint32_t IndexFromName(char name) { - switch (name) { - case 'x': - case 'r': - return 0; - case 'y': - case 'g': - return 1; - case 'z': - case 'b': - return 2; - case 'w': - case 'a': - return 3; - } - return std::numeric_limits::max(); -} - /// Returns the matrix type that is `type` or that is wrapped by /// one or more levels of an arrays inside of `type`. /// @param type the given type, which must not be null @@ -880,23 +863,11 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, AccessorInfo* info) { - auto* data_type = - TypeOf(expr->structure())->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); - auto* expr_type = TypeOf(expr); + auto* expr_sem = builder_.Sem().Get(expr); + auto* expr_type = expr_sem->Type(); - // If the data_type is a structure we're accessing a member, if it's a - // vector we're accessing a swizzle. - if (auto* str = data_type->As()) { - auto* impl = str->Declaration(); - auto symbol = expr->member()->symbol(); - - uint32_t idx = 0; - for (; idx < impl->members().size(); ++idx) { - auto* member = impl->members()[idx]; - if (member->symbol() == symbol) { - break; - } - } + if (auto* access = expr_sem->As()) { + uint32_t idx = access->Member()->Index(); if (info->source_type->Is()) { auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx)); @@ -927,106 +898,93 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, return true; } - if (!data_type->Is()) { - error_ = "Member accessor without a struct or vector. Something is wrong"; - return false; - } + if (auto* swizzle = expr_sem->As()) { + // Single element swizzle is either an access chain or a composite extract + auto& indices = swizzle->Indices(); + if (indices.size() == 1) { + if (info->source_type->Is()) { + auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0])); + if (idx_id == 0) { + return 0; + } + info->access_chain_indices.push_back(idx_id); + } else { + auto result_type_id = GenerateTypeIfNeeded(expr_type); + if (result_type_id == 0) { + return 0; + } - // TODO(dsinclair): Swizzle stuff - auto swiz = builder_.Symbols().NameFor(expr->member()->symbol()); - // Single element swizzle is either an access chain or a composite extract - if (swiz.size() == 1) { - auto val = IndexFromName(swiz[0]); - if (val == std::numeric_limits::max()) { - error_ = "invalid swizzle name: " + swiz; - return false; + auto extract = result_op(); + auto extract_id = extract.to_i(); + if (!push_function_inst( + spv::Op::OpCompositeExtract, + {Operand::Int(result_type_id), extract, + Operand::Int(info->source_id), Operand::Int(indices[0])})) { + return false; + } + + info->source_id = extract_id; + info->source_type = expr_type; + } + return true; } - if (info->source_type->Is()) { - auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val)); - if (idx_id == 0) { - return 0; - } - info->access_chain_indices.push_back(idx_id); - } else { - auto result_type_id = GenerateTypeIfNeeded(expr_type); + // Store the type away as it may change if we run the access chain + auto* incoming_type = info->source_type; + + // Multi-item extract is a VectorShuffle. We have to emit any existing + // access chain data, then load the access chain and shuffle that. + if (!info->access_chain_indices.empty()) { + auto result_type_id = GenerateTypeIfNeeded(info->source_type); if (result_type_id == 0) { return 0; } - auto extract = result_op(); auto extract_id = extract.to_i(); - if (!push_function_inst( - spv::Op::OpCompositeExtract, - {Operand::Int(result_type_id), extract, - Operand::Int(info->source_id), Operand::Int(val)})) { + + OperandList ops = {Operand::Int(result_type_id), extract, + Operand::Int(info->source_id)}; + for (auto id : info->access_chain_indices) { + ops.push_back(Operand::Int(id)); + } + + if (!push_function_inst(spv::Op::OpAccessChain, ops)) { return false; } - info->source_id = extract_id; - info->source_type = expr_type; + info->source_id = GenerateLoadIfNeeded(expr_type, extract_id); + info->source_type = expr_type->UnwrapPtrIfNeeded(); + info->access_chain_indices.clear(); } + + auto result_type_id = GenerateTypeIfNeeded(expr_type); + if (result_type_id == 0) { + return false; + } + + auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id); + + auto result = result_op(); + auto result_id = result.to_i(); + + OperandList ops = {Operand::Int(result_type_id), result, + Operand::Int(vec_id), Operand::Int(vec_id)}; + + for (auto idx : indices) { + ops.push_back(Operand::Int(idx)); + } + + if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) { + return false; + } + info->source_id = result_id; + info->source_type = expr_type; return true; } - // Store the type away as it may change if we run the access chain - auto* incoming_type = info->source_type; - - // Multi-item extract is a VectorShuffle. We have to emit any existing access - // chain data, then load the access chain and shuffle that. - if (!info->access_chain_indices.empty()) { - auto result_type_id = GenerateTypeIfNeeded(info->source_type); - if (result_type_id == 0) { - return 0; - } - auto extract = result_op(); - auto extract_id = extract.to_i(); - - OperandList ops = {Operand::Int(result_type_id), extract, - Operand::Int(info->source_id)}; - for (auto id : info->access_chain_indices) { - ops.push_back(Operand::Int(id)); - } - - if (!push_function_inst(spv::Op::OpAccessChain, ops)) { - return false; - } - - info->source_id = GenerateLoadIfNeeded(expr_type, extract_id); - info->source_type = expr_type->UnwrapPtrIfNeeded(); - info->access_chain_indices.clear(); - } - - auto result_type_id = GenerateTypeIfNeeded(expr_type); - if (result_type_id == 0) { - return false; - } - - auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id); - - auto result = result_op(); - auto result_id = result.to_i(); - - OperandList ops = {Operand::Int(result_type_id), result, Operand::Int(vec_id), - Operand::Int(vec_id)}; - - for (uint32_t i = 0; i < swiz.size(); ++i) { - auto val = IndexFromName(swiz[i]); - if (val == std::numeric_limits::max()) { - error_ = "invalid swizzle name: " + swiz; - return false; - } - - ops.push_back(Operand::Int(val)); - } - - if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) { - return false; - } - info->source_id = result_id; - info->source_type = expr_type; - - return true; + TINT_ICE(builder_.Diagnostics()) + << "unhandled member index type: " << expr_sem->TypeInfo().name; + return false; } uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {