diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 8a74ad5014..7148a78374 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -822,11 +822,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, auto* type = TypeOf(expr->idx_expr()); idx_id = GenerateLoadIfNeeded(type, idx_id); - // If the source is a pointer we access chain into it. We also access chain - // into an array of non-scalar types. - if (info->source_type->Is() || - (info->source_type->Is() && - !info->source_type->As()->type()->is_scalar())) { + // If the source is a pointer, we access chain into it. + if (info->source_type->Is()) { info->access_chain_indices.push_back(idx_id); info->source_type = TypeOf(expr); return true; @@ -1063,17 +1060,21 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { } info.source_type = TypeOf(source); - // If our initial access is into an array of non-scalar types, and that array - // is not a pointer, then we need to load that array into a variable in order - // to access chain into the array. + // If our initial access is into a non-pointer array, and either has a + // non-scalar element type or the accessor uses a non-literal index, then we + // need to load that array into a variable in order to access chain into it. + // TODO(jrprice): The non-scalar part shouldn't be necessary, but is tied to + // how the Resolver currently determines the type of these expression. This + // should be fixed when proper support for ptr/ref types is implemented. if (auto* array = accessors[0]->As()) { - auto* ary_res_type = TypeOf(array->array()); - - if (!ary_res_type->Is() && - (ary_res_type->Is() && - !ary_res_type->As()->type()->is_scalar())) { - sem::Pointer ptr(ary_res_type, ast::StorageClass::kFunction); - auto result_type_id = GenerateTypeIfNeeded(&ptr); + auto* ary_res_type = TypeOf(array->array())->As(); + if (ary_res_type && + (!ary_res_type->type()->is_scalar() || + !array->idx_expr()->Is())) { + // Wrap the source type in a pointer to function storage. + auto ptr = + builder_.ty.pointer(ary_res_type, ast::StorageClass::kFunction); + auto result_type_id = GenerateTypeIfNeeded(ptr); if (result_type_id == 0) { return 0; } @@ -1094,6 +1095,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { } info.source_id = ary_result.to_i(); + info.source_type = ptr; } } @@ -1115,7 +1117,17 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { } if (!info.access_chain_indices.empty()) { - auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr)); + bool needs_load = false; + auto* ptr = TypeOf(expr); + if (!ptr->Is()) { + // We are performing an access chain but the final result is not a + // pointer, so we need to perform a load to get it. This happens when we + // have to copy the source expression into a function variable. + ptr = builder_.ty.pointer(ptr, ast::StorageClass::kFunction); + needs_load = true; + } + + auto result_type_id = GenerateTypeIfNeeded(ptr); if (result_type_id == 0) { return 0; } @@ -1133,6 +1145,11 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { return false; } info.source_id = result_id; + + // Load from the access chain result if required. + if (needs_load) { + info.source_id = GenerateLoadIfNeeded(ptr, result_id); + } } return info.source_id; diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index d0c409b4a7..80819a0867 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -895,13 +895,54 @@ TEST_F(BuilderTest, Accessor_Array_NonPointer) { )"); } -TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer_Dynamic) { +TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) { // let a : array; // idx : i32 // a[idx] - // - // This needs to copy the array to an OpVariable in the Function storage class - // and then access chain into it and load the result. + + auto* var = GlobalConst("a", ty.array(), + Construct(ty.array(), 0.0f, 0.5f, 1.0f)); + + auto* idx = Var("idx", ty.i32(), ast::StorageClass::kFunction); + auto* expr = IndexAccessor("a", idx); + + ast::StatementList body; + body.push_back(WrapInStatement(idx)); + body.push_back(WrapInStatement(expr)); + WrapInFunction(body); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); + ASSERT_TRUE(b.GenerateFunctionVariable(idx)) << b.error(); + EXPECT_EQ(b.GenerateAccessorExpression(expr), 19u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 3 +%1 = OpTypeArray %2 %4 +%5 = OpConstant %2 0 +%6 = OpConstant %2 0.5 +%7 = OpConstant %2 1 +%8 = OpConstantComposite %1 %5 %6 %7 +%11 = OpTypeInt 32 1 +%10 = OpTypePointer Function %11 +%12 = OpConstantNull %11 +%13 = OpTypePointer Function %1 +%15 = OpConstantNull %1 +%17 = OpTypePointer Function %2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), + R"(%9 = OpVariable %10 Function %12 +%14 = OpVariable %13 Function %15 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpStore %14 %8 +%16 = OpLoad %11 %9 +%18 = OpAccessChain %17 %14 %16 +%19 = OpLoad %2 %18 +)"); } } // namespace