writer/spirv: Fix dynamic array accessors

If the initial array accessor in a chain uses a non-literal index, use
the path that copies the source to a function variable, and then
perform a load from the OpAccessChain result if necessary.

Fixed: tint:426
Change-Id: Ie2f3f388170c02c1d6b73355f0b3bc49c3d3a4e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49800
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
James Price 2021-05-06 21:45:03 +00:00 committed by Commit Bot service account
parent a2580d6720
commit 698d01383c
2 changed files with 78 additions and 20 deletions

View File

@ -822,11 +822,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
auto* type = TypeOf(expr->idx_expr()); auto* type = TypeOf(expr->idx_expr());
idx_id = GenerateLoadIfNeeded(type, idx_id); idx_id = GenerateLoadIfNeeded(type, idx_id);
// If the source is a pointer we access chain into it. We also access chain // If the source is a pointer, we access chain into it.
// into an array of non-scalar types. if (info->source_type->Is<sem::Pointer>()) {
if (info->source_type->Is<sem::Pointer>() ||
(info->source_type->Is<sem::ArrayType>() &&
!info->source_type->As<sem::ArrayType>()->type()->is_scalar())) {
info->access_chain_indices.push_back(idx_id); info->access_chain_indices.push_back(idx_id);
info->source_type = TypeOf(expr); info->source_type = TypeOf(expr);
return true; return true;
@ -1063,17 +1060,21 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
} }
info.source_type = TypeOf(source); info.source_type = TypeOf(source);
// If our initial access is into an array of non-scalar types, and that array // If our initial access is into a non-pointer array, and either has a
// is not a pointer, then we need to load that array into a variable in order // non-scalar element type or the accessor uses a non-literal index, then we
// to access chain into the array. // need to load that array into a variable in order to access chain into it.
// TODO(jrprice): The non-scalar part shouldn't be necessary, but is tied to
// how the Resolver currently determines the type of these expression. This
// should be fixed when proper support for ptr/ref types is implemented.
if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) { if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) {
auto* ary_res_type = TypeOf(array->array()); auto* ary_res_type = TypeOf(array->array())->As<sem::ArrayType>();
if (ary_res_type &&
if (!ary_res_type->Is<sem::Pointer>() && (!ary_res_type->type()->is_scalar() ||
(ary_res_type->Is<sem::ArrayType>() && !array->idx_expr()->Is<ast::ScalarConstructorExpression>())) {
!ary_res_type->As<sem::ArrayType>()->type()->is_scalar())) { // Wrap the source type in a pointer to function storage.
sem::Pointer ptr(ary_res_type, ast::StorageClass::kFunction); auto ptr =
auto result_type_id = GenerateTypeIfNeeded(&ptr); builder_.ty.pointer(ary_res_type, ast::StorageClass::kFunction);
auto result_type_id = GenerateTypeIfNeeded(ptr);
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -1094,6 +1095,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
} }
info.source_id = ary_result.to_i(); info.source_id = ary_result.to_i();
info.source_type = ptr;
} }
} }
@ -1115,7 +1117,17 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
} }
if (!info.access_chain_indices.empty()) { if (!info.access_chain_indices.empty()) {
auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr)); bool needs_load = false;
auto* ptr = TypeOf(expr);
if (!ptr->Is<sem::Pointer>()) {
// We are performing an access chain but the final result is not a
// pointer, so we need to perform a load to get it. This happens when we
// have to copy the source expression into a function variable.
ptr = builder_.ty.pointer(ptr, ast::StorageClass::kFunction);
needs_load = true;
}
auto result_type_id = GenerateTypeIfNeeded(ptr);
if (result_type_id == 0) { if (result_type_id == 0) {
return 0; return 0;
} }
@ -1133,6 +1145,11 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
return false; return false;
} }
info.source_id = result_id; info.source_id = result_id;
// Load from the access chain result if required.
if (needs_load) {
info.source_id = GenerateLoadIfNeeded(ptr, result_id);
}
} }
return info.source_id; return info.source_id;

View File

@ -895,13 +895,54 @@ TEST_F(BuilderTest, Accessor_Array_NonPointer) {
)"); )");
} }
TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer_Dynamic) { TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
// let a : array<f32, 3>; // let a : array<f32, 3>;
// idx : i32 // idx : i32
// a[idx] // a[idx]
//
// This needs to copy the array to an OpVariable in the Function storage class auto* var = GlobalConst("a", ty.array<f32, 3>(),
// and then access chain into it and load the result. Construct(ty.array<f32, 3>(), 0.0f, 0.5f, 1.0f));
auto* idx = Var("idx", ty.i32(), ast::StorageClass::kFunction);
auto* expr = IndexAccessor("a", idx);
ast::StatementList body;
body.push_back(WrapInStatement(idx));
body.push_back(WrapInStatement(expr));
WrapInFunction(body);
spirv::Builder& b = Build();
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
ASSERT_TRUE(b.GenerateFunctionVariable(idx)) << b.error();
EXPECT_EQ(b.GenerateAccessorExpression(expr), 19u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%3 = OpTypeInt 32 0
%4 = OpConstant %3 3
%1 = OpTypeArray %2 %4
%5 = OpConstant %2 0
%6 = OpConstant %2 0.5
%7 = OpConstant %2 1
%8 = OpConstantComposite %1 %5 %6 %7
%11 = OpTypeInt 32 1
%10 = OpTypePointer Function %11
%12 = OpConstantNull %11
%13 = OpTypePointer Function %1
%15 = OpConstantNull %1
%17 = OpTypePointer Function %2
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
R"(%9 = OpVariable %10 Function %12
%14 = OpVariable %13 Function %15
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpStore %14 %8
%16 = OpLoad %11 %9
%18 = OpAccessChain %17 %14 %16
%19 = OpLoad %2 %18
)");
} }
} // namespace } // namespace