diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 7a55ac7e9a..147063239d 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -978,13 +978,16 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { } info.source_type = source->result_type(); - // If our initial access in into an array, and that array is not a pointer, - // then we need to load that array into a variable in order to be access - // chain into the array + // If our initial access is into an array of non-scalar types, and that array + // is not a pointer, then we need to load that array into a variable in order + // to access chain into the array. if (accessors[0]->IsArrayAccessor()) { auto* ary_res_type = accessors[0]->AsArrayAccessor()->array()->result_type(); - if (!ary_res_type->IsPointer()) { + + if (!ary_res_type->IsPointer() && + (ary_res_type->IsArray() && + !ary_res_type->AsArray()->type()->is_scalar())) { ast::type::PointerType ptr(ary_res_type, ast::StorageClass::kFunction); auto result_type_id = GenerateTypeIfNeeded(&ptr); if (result_type_id == 0) { diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index f0e5b0bd2f..5da3d17eb1 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -946,6 +946,51 @@ TEST_F(BuilderTest, Accessor_Array_Of_Vec) { )"); } +TEST_F(BuilderTest, Accessor_Const_Vec) { + // const pos : vec2 = vec2(0.0, 0.5); + // pos[1] + + ast::type::F32Type f32; + ast::type::U32Type u32; + ast::type::VectorType vec(&f32, 2); + + ast::ExpressionList vec_params; + vec_params.push_back(create( + create(&f32, 0.0))); + vec_params.push_back(create( + create(&f32, 0.5))); + + ast::Variable var("pos", ast::StorageClass::kPrivate, &vec); + var.set_is_const(true); + var.set_constructor( + create(&vec, std::move(vec_params))); + + ast::ArrayAccessorExpression expr(create("pos"), + create( + create(&u32, 1))); + + td.RegisterVariableForTesting(&var); + ASSERT_TRUE(td.DetermineResultType(var.constructor())) << td.error(); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); + EXPECT_EQ(b.GenerateAccessorExpression(&expr), 8u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 2 +%3 = OpConstant %2 0 +%4 = OpConstant %2 0.5 +%5 = OpConstantComposite %1 %3 %4 +%6 = OpTypeInt 32 0 +%7 = OpConstant %6 1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), ""); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%8 = OpVectorExtractDynamic %2 %5 %7 +)"); +} + TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer) { // const a : array; // a[2]