diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 806748cc4a..3687bf9ba4 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2906,7 +2906,7 @@ TypedExpression FunctionEmitter::MakeAccessChain( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); } - // All vector components are the same type, so follow the first. + // All vector components are the same type. pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; case SpvOpTypeMatrix: @@ -2914,7 +2914,7 @@ TypedExpression FunctionEmitter::MakeAccessChain( next_expr = std::make_unique( std::move(current_expr.expr), std::move(MakeOperand(inst, index).expr)); - // All matrix components are the same type, so follow the first. + // All matrix components are the same type. pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0); break; case SpvOpTypeArray: @@ -2988,20 +2988,31 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( static const char* swizzles[] = {"x", "y", "z", "w"}; const auto composite = inst.GetSingleWordInOperand(0); - const auto composite_type_id = def_use_mgr_->GetDef(composite)->type_id(); - const auto* current_type = type_mgr_->GetType(composite_type_id); + auto current_type_id = def_use_mgr_->GetDef(composite)->type_id(); + // Build up a nested expression for the access chain by walking down the type + // hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of + // the object pointed to after processing the previous indices. const auto num_in_operands = inst.NumInOperands(); for (uint32_t index = 1; index < num_in_operands; ++index) { const uint32_t index_val = inst.GetSingleWordInOperand(index); + + const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id); + if (!current_type_inst) { + Fail() << "composite type %" << current_type_id + << " is invalid after following " << (index - 1) + << " indices: " << inst.PrettyPrint(); + return {}; + } std::unique_ptr next_expr; - switch (current_type->kind()) { - case spvtools::opt::analysis::Type::kVector: { + switch (current_type_inst->opcode()) { + case SpvOpTypeVector: { // Try generating a MemberAccessor expression. That result in something // like "foo.z", which is more idiomatic than "foo[2]". - if (current_type->AsVector()->element_count() <= index_val) { + const auto num_elems = current_type_inst->GetSingleWordInOperand(1); + if (num_elems <= index_val) { Fail() << "CompositeExtract %" << inst.result_id() << " index value " - << index_val << " is out of bounds for vector of " - << current_type->AsVector()->element_count() << " elements"; + << index_val << " is out of bounds for vector of " << num_elems + << " elements"; return {}; } if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { @@ -3013,15 +3024,17 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( std::make_unique(swizzles[index_val]); next_expr = std::make_unique( std::move(current_expr.expr), std::move(letter_index)); - current_type = current_type->AsVector()->element_type(); + // All vector components are the same type. + current_type_id = current_type_inst->GetSingleWordInOperand(0); break; } - case spvtools::opt::analysis::Type::kMatrix: + case SpvOpTypeMatrix: { // Check bounds - if (current_type->AsMatrix()->element_count() <= index_val) { + const auto num_elems = current_type_inst->GetSingleWordInOperand(1); + if (num_elems <= index_val) { Fail() << "CompositeExtract %" << inst.result_id() << " index value " - << index_val << " is out of bounds for matrix of " - << current_type->AsMatrix()->element_count() << " elements"; + << index_val << " is out of bounds for matrix of " << num_elems + << " elements"; return {}; } if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { @@ -3032,45 +3045,44 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( // Use array syntax. next_expr = std::make_unique( std::move(current_expr.expr), make_index(index_val)); - current_type = current_type->AsMatrix()->element_type(); + // All matrix components are the same type. + current_type_id = current_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kArray: + } + case SpvOpTypeArray: // The array size could be a spec constant, and so it's not always // statically checkable. Instead, rely on a runtime index clamp // or runtime check to keep this safe. next_expr = std::make_unique( std::move(current_expr.expr), make_index(index_val)); - current_type = current_type->AsArray()->element_type(); + current_type_id = current_type_inst->GetSingleWordInOperand(0); break; - case spvtools::opt::analysis::Type::kRuntimeArray: + case SpvOpTypeRuntimeArray: Fail() << "can't do OpCompositeExtract on a runtime array"; return {}; - case spvtools::opt::analysis::Type::kStruct: { - if (current_type->AsStruct()->element_types().size() <= index_val) { + case SpvOpTypeStruct: { + const auto num_members = current_type_inst->NumInOperands(); + if (num_members <= index_val) { Fail() << "CompositeExtract %" << inst.result_id() << " index value " << index_val << " is out of bounds for structure %" - << type_mgr_->GetId(current_type) << " having " - << current_type->AsStruct()->element_types().size() - << " elements"; + << current_type_id << " having " << num_members << " members"; return {}; } - auto member_access = - std::make_unique(namer_.GetMemberName( - type_mgr_->GetId(current_type), uint32_t(index_val))); + auto member_access = std::make_unique( + namer_.GetMemberName(current_type_id, uint32_t(index_val))); next_expr = std::make_unique( std::move(current_expr.expr), std::move(member_access)); - current_type = current_type->AsStruct()->element_types()[index_val]; + current_type_id = current_type_inst->GetSingleWordInOperand(index_val); break; } default: - Fail() << "CompositeExtract with bad type %" - << type_mgr_->GetId(current_type) << " " << current_type->str(); + Fail() << "CompositeExtract with bad type %" << current_type_id << ": " + << current_type_inst->PrettyPrint(); return {}; } current_expr.reset(TypedExpression( - parser_impl_.ConvertType(type_mgr_->GetId(current_type)), - std::move(next_expr))); + parser_impl_.ConvertType(current_type_id), std::move(next_expr))); } return current_expr; } diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index b3997018aa..f9ffd29dc3 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -451,6 +451,61 @@ TEST_F(SpvParserTest_CompositeExtract, Struct) { << ToString(fe.ast_body()); } +TEST_F(SpvParserTest_CompositeExtract, Struct_DifferOnlyInMemberName) { + const auto assembly = + R"( + OpMemberName %s0 0 "algo" + OpMemberName %s1 0 "rithm" +)" + Preamble() + + R"( + %s0 = OpTypeStruct %uint + %s1 = OpTypeStruct %uint + %ptr0 = OpTypePointer Function %s0 + %ptr1 = OpTypePointer Function %s1 + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %var0 = OpVariable %ptr0 Function + %var1 = OpVariable %ptr1 Function + %1 = OpLoad %s0 %var0 + %2 = OpCompositeExtract %uint %1 0 + %3 = OpLoad %s1 %var1 + %4 = OpCompositeExtract %uint %3 0 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __u32 + { + MemberAccessor{ + Identifier{x_1} + Identifier{algo} + } + } + })")) + << ToString(fe.ast_body()); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_4 + none + __u32 + { + MemberAccessor{ + Identifier{x_3} + Identifier{rithm} + } + } + })")) + << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { const auto assembly = Preamble() + R"( %ptr = OpTypePointer Function %s_v2f_u_i @@ -468,7 +523,7 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of " - "bounds for structure %25 having 3 elements")); + "bounds for structure %25 having 3 members")); } TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {