diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 9483257edb..eee67439f6 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -520,6 +520,7 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, // If the source is a pointer we access chain into it. if (info->source_type->IsPointer()) { info->access_chain_indices.push_back(idx_id); + info->source_type = expr->result_type(); return true; } @@ -538,6 +539,7 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, info->source_id = extract_id; info->source_type = expr->result_type(); + return true; } @@ -616,6 +618,9 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, return true; } + // Store the type away as it may change if we run the access chain + auto* incoming_type = info->source_type; + // Multi-item extract is a VectorShuffle. We have to emit any existing access // chain data, then load the access chain and shuffle that. if (!info->access_chain_indices.empty()) { @@ -644,7 +649,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, return false; } - auto vec_id = GenerateLoadIfNeeded(info->source_type, info->source_id); + auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id); auto result = result_op(); auto result_id = result.to_i(); diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index 3b3c75990b..4d78d1dc8f 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -232,6 +232,56 @@ TEST_F(BuilderTest, ArrayAccessor_MultiLevel) { )"); } +TEST_F(BuilderTest, Accessor_ArrayWithSwizzle) { + ast::type::I32Type i32; + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::ArrayType ary4(&vec3, 4); + + // var a : array, 4>; + // a[2].xy; + + ast::Variable var("ary", ast::StorageClass::kFunction, &ary4); + + ast::MemberAccessorExpression expr( + std::make_unique( + std::make_unique("ary"), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique("xy")); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(&var); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); + EXPECT_EQ(b.GenerateAccessorExpression(&expr), 14u); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%6 = OpTypeInt 32 0 +%7 = OpConstant %6 4 +%3 = OpTypeArray %4 %7 +%2 = OpTypePointer Function %3 +%8 = OpTypeInt 32 1 +%9 = OpConstant %8 2 +%10 = OpTypePointer Function %4 +%12 = OpTypeVector %5 2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), + R"(%1 = OpVariable %2 Function +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%11 = OpAccessChain %10 %1 %9 +%13 = OpLoad %4 %11 +%14 = OpVectorShuffle %12 %13 %13 0 1 +)"); +} + TEST_F(BuilderTest, MemberAccessor) { ast::type::F32Type f32; @@ -847,7 +897,7 @@ TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) { b.push_function(Function{}); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); - EXPECT_EQ(b.GenerateAccessorExpression(&expr), 18u); + EXPECT_EQ(b.GenerateAccessorExpression(&expr), 19u); EXPECT_EQ(DumpInstructions(b.types()), R"(%9 = OpTypeFloat 32 %8 = OpTypeVector %9 3 @@ -870,7 +920,8 @@ TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) { )"); EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%16 = OpAccessChain %15 %1 %14 %14 %12 %14 %14 -%18 = OpVectorShuffle %17 %16 %16 1 0 +%18 = OpLoad %8 %16 +%19 = OpVectorShuffle %17 %18 %18 1 0 )"); }