diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 2578e06233..806748cc4a 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2851,30 +2851,43 @@ TypedExpression FunctionEmitter::MakeAccessChain( } } - const auto* ptr_type = type_mgr_->GetType(ptr_ty_id); - if (!ptr_type || !ptr_type->AsPointer()) { + const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id); + if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) { Fail() << "Access chain %" << inst.result_id() << " base pointer is not of pointer type"; return {}; } - SpvStorageClass storage_class = ptr_type->AsPointer()->storage_class(); - const auto* pointee_type = ptr_type->AsPointer()->pointee_type(); + SpvStorageClass storage_class = + static_cast(ptr_type_inst->GetSingleWordInOperand(0)); + uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); + + // Build up a nested expression for the access chain by walking down the type + // hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of + // the object pointed to after processing the previous indices. for (uint32_t index = first_index; index < num_in_operands; ++index) { const auto* index_const = constants[index] ? constants[index]->AsIntConstant() : nullptr; const int64_t index_const_val = index_const ? index_const->GetSignExtendedValue() : 0; std::unique_ptr next_expr; - switch (pointee_type->kind()) { - case spvtools::opt::analysis::Type::kVector: + + const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id); + if (!pointee_type_inst) { + Fail() << "pointee type %" << pointee_type_id + << " is invalid after following " << (index - first_index) + << " indices: " << inst.PrettyPrint(); + return {}; + } + switch (pointee_type_inst->opcode()) { + case SpvOpTypeVector: if (index_const) { - // Try generating a MemberAccessor expression. - if (index_const_val < 0 || - pointee_type->AsVector()->element_count() <= index_const_val) { + // Try generating a MemberAccessor expression + const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1); + if (index_const_val < 0 || num_elems <= index_const_val) { Fail() << "Access chain %" << inst.result_id() << " index %" << inst.GetSingleWordInOperand(index) << " value " << index_const_val << " is out of bounds for vector of " - << pointee_type->AsVector()->element_count() << " elements"; + << num_elems << " elements"; return {}; } if (uint64_t(index_const_val) >= @@ -2893,61 +2906,58 @@ TypedExpression FunctionEmitter::MakeAccessChain( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); } - pointee_type = pointee_type->AsVector()->element_type(); + // All vector components are the same type, so follow the first. + pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kMatrix: + case SpvOpTypeMatrix: // Use array syntax. next_expr = std::make_unique( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); - pointee_type = pointee_type->AsMatrix()->element_type(); + // All matrix components are the same type, so follow the first. + pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kArray: + case SpvOpTypeArray: next_expr = std::make_unique( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); - pointee_type = pointee_type->AsArray()->element_type(); + pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kRuntimeArray: + case SpvOpTypeRuntimeArray: next_expr = std::make_unique( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); - pointee_type = pointee_type->AsRuntimeArray()->element_type(); + pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kStruct: { + case SpvOpTypeStruct: { if (!index_const) { Fail() << "Access chain %" << inst.result_id() << " index %" << inst.GetSingleWordInOperand(index) << " is a non-constant index into a structure %" - << type_mgr_->GetId(pointee_type); + << pointee_type_id; return {}; } - if ((index_const_val < 0) || - pointee_type->AsStruct()->element_types().size() <= - uint64_t(index_const_val)) { + const auto num_members = pointee_type_inst->NumInOperands(); + if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) { Fail() << "Access chain %" << inst.result_id() << " index value " << index_const_val << " is out of bounds for structure %" - << type_mgr_->GetId(pointee_type) << " having " - << pointee_type->AsStruct()->element_types().size() - << " elements"; + << pointee_type_id << " having " << num_members << " members"; return {}; } - auto member_access = - std::make_unique(namer_.GetMemberName( - type_mgr_->GetId(pointee_type), uint32_t(index_const_val))); + auto member_access = std::make_unique( + namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val))); next_expr = std::make_unique( std::move(current_expr.expr), std::move(member_access)); - pointee_type = - pointee_type->AsStruct()->element_types()[index_const_val]; + pointee_type_id = pointee_type_inst->GetSingleWordInOperand( + static_cast(index_const_val)); break; } default: - Fail() << "Access chain with unknown pointee type %" - << type_mgr_->GetId(pointee_type) << " " << pointee_type->str(); + Fail() << "Access chain with unknown or invalid pointee type %" + << pointee_type_id << ": " << pointee_type_inst->PrettyPrint(); return {}; } - const auto pointee_type_id = type_mgr_->GetId(pointee_type); const auto pointer_type_id = type_mgr_->FindPointerToType(pointee_type_id, storage_class); auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc index 149b304c66..5d4a4366c6 100644 --- a/src/reader/spirv/function_memory_test.cc +++ b/src/reader/spirv/function_memory_test.cc @@ -347,7 +347,7 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorSwizzle) { Identifier{z} } ScalarConstructor{42} -})")); +})")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorConstOutOfBounds) { @@ -535,6 +535,61 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct) { })")); } +TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_DifferOnlyMemberName) { + // The spirv-opt internal representation will map both structs to the + // same canonicalized type, because it doesn't care about member names. + // But we care about member names when producing a member-access expression. + // crbug.com/tint/213 + const std::string assembly = R"( + OpName %1 "myvar" + OpName %10 "myvar2" + OpMemberName %strct 1 "age" + OpMemberName %strct2 1 "ancientness" + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %float_42 = OpConstant %float 42 + %float_420 = OpConstant %float 420 + %strct = OpTypeStruct %float %float + %strct2 = OpTypeStruct %float %float + %elem_ty = OpTypePointer Workgroup %float + %var_ty = OpTypePointer Workgroup %strct + %var2_ty = OpTypePointer Workgroup %strct2 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + + %1 = OpVariable %var_ty Workgroup + %10 = OpVariable %var2_ty Workgroup + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %2 = OpAccessChain %elem_ty %1 %uint_1 + OpStore %2 %float_42 + %20 = OpAccessChain %elem_ty %10 %uint_1 + OpStore %20 %float_420 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) + << assembly << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{ + MemberAccessor{ + Identifier{myvar} + Identifier{age} + } + ScalarConstructor{42.000000} +} +Assignment{ + MemberAccessor{ + Identifier{myvar2} + Identifier{ancientness} + } + ScalarConstructor{420.000000} +})")) << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest, EmitStatement_AccessChain_StructNonConstIndex) { const std::string assembly = R"( OpName %1 "myvar" @@ -597,7 +652,7 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_StructConstOutOfBounds) { FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("Access chain %2 index value 99 is out of bounds " - "for structure %55 having 2 elements")); + "for structure %55 having 2 members")); } TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_RuntimeArray) { @@ -705,7 +760,8 @@ TEST_F(SpvParserTest, EmitStatement_AccessChain_InvalidPointeeType) { FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), - HasSubstr("Access chain with unknown pointee type %60 void")); + HasSubstr("Access chain with unknown or invalid pointee type " + "%60: %60 = OpTypePointer Workgroup %55")); } std::string OldStorageBufferPreamble() {