tint: Fix transform errors when calling arrayLength() as a statement

Bug: chromium:1360925
Change-Id: If60fa4bb1cf4981c10dd15f8814c0aed70c0066e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101780
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-09-09 20:42:29 +00:00 committed by Dawn LUCI CQ
parent f313c48030
commit 4b70776aed
12 changed files with 254 additions and 67 deletions

View File

@ -21,6 +21,7 @@
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/simplify_pointers.h"
@ -33,65 +34,79 @@ namespace tint::transform {
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
/// @param ctx the CloneContext.
/// @param functor of type void(const ast::CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer.
template <typename F>
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
auto& sem = ctx.src->Sem();
/// The PIMPL state for this transform
struct ArrayLengthFromUniform::State {
/// The clone context
CloneContext& ctx;
// Find all calls to the arrayLength() builtin.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
if (!call_expr) {
continue;
}
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
/// @param functor of type void(const ast::CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer.
template <typename F>
void IterateArrayLengthOnStorageVar(F&& functor) {
auto& sem = ctx.src->Sem();
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
continue;
}
// Find all calls to the arrayLength() builtin.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
if (!call_expr) {
continue;
}
// Get the storage buffer that contains the runtime array.
// Since we require SimplifyPointers, we can assume that the arrayLength()
// call has one of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (!param || param->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* storage_buffer_expr = param->expr;
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
storage_buffer_expr = accessor->structure;
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
continue;
}
// Get the index to use for the buffer size array.
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
if (!var) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "storage buffer is not a global variable";
break;
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the statement.
RemoveStatement(ctx, call_stmt);
continue;
}
}
// Get the storage buffer that contains the runtime array.
// Since we require SimplifyPointers, we can assume that the arrayLength()
// call has one of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (!param || param->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* storage_buffer_expr = param->expr;
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
storage_buffer_expr = accessor->structure;
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
// Get the index to use for the buffer size array.
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
if (!var) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "storage buffer is not a global variable";
break;
}
functor(call_expr, storage_buffer_sem, var);
}
functor(call_expr, storage_buffer_sem, var);
}
}
};
bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
@ -119,17 +134,17 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataM
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
State{ctx}.IterateArrayLengthOnStorageVar(
[&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
@ -156,9 +171,9 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataM
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {

View File

@ -113,6 +113,10 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
/// The PIMPL state for this transform
struct State;
};
} // namespace tint::transform

View File

@ -496,5 +496,43 @@ struct SB {
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, CallStatement) {
auto* src = R"(
struct SB {
arr : array<i32>,
}
@group(0) @binding(0) var<storage, read> a : SB;
@compute @workgroup_size(1)
fn main() {
arrayLength(&a.arr);
}
)";
auto* expect =
R"(
struct SB {
arr : array<i32>,
}
@group(0) @binding(0) var<storage, read> a : SB;
@compute @workgroup_size(1)
fn main() {
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@ -130,6 +130,16 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// We're dealing with an arrayLength() call
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the
// statement.
RemoveStatement(ctx, call_stmt);
continue;
}
}
// A runtime-sized array can only appear as the store type of a variable, or the
// last element of a structure (which cannot itself be nested). Given that we
// require SimplifyPointers, we can assume that the arrayLength() call has one

View File

@ -547,5 +547,37 @@ struct SB2 {
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, CallStatement) {
auto* src = R"(
struct SB {
arr : array<i32>,
}
@group(0) @binding(0) var<storage, read> a : SB;
@compute @workgroup_size(1)
fn main() {
arrayLength(&a.arr);
}
)";
auto* expect =
R"(
struct SB {
arr : array<i32>,
}
@group(0) @binding(0) var<storage, read> a : SB;
@compute @workgroup_size(1)
fn main() {
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@ -0,0 +1,7 @@
@group(0) @binding(0)
var<storage> G : array<i32>;
fn n() {
let p = &G;
arrayLength(p);
}

View File

@ -0,0 +1,9 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
ByteAddressBuffer G : register(t0, space0);
void n() {
}

View File

@ -0,0 +1,9 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
ByteAddressBuffer G : register(t0, space0);
void n() {
}

View File

@ -0,0 +1,14 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
layout(binding = 0, std430) buffer G_block_ssbo {
int inner[];
} G;
void n() {
uint(G.inner.length());
}

View File

@ -0,0 +1,6 @@
#include <metal_stdlib>
using namespace metal;
void n() {
}

View File

@ -0,0 +1,37 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 14
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %G_block "G_block"
OpMemberName %G_block 0 "inner"
OpName %G "G"
OpName %unused_entry_point "unused_entry_point"
OpName %n "n"
OpDecorate %G_block Block
OpMemberDecorate %G_block 0 Offset 0
OpDecorate %_runtimearr_int ArrayStride 4
OpDecorate %G NonWritable
OpDecorate %G DescriptorSet 0
OpDecorate %G Binding 0
%int = OpTypeInt 32 1
%_runtimearr_int = OpTypeRuntimeArray %int
%G_block = OpTypeStruct %_runtimearr_int
%_ptr_StorageBuffer_G_block = OpTypePointer StorageBuffer %G_block
%G = OpVariable %_ptr_StorageBuffer_G_block StorageBuffer
%void = OpTypeVoid
%6 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%unused_entry_point = OpFunction %void None %6
%9 = OpLabel
OpReturn
OpFunctionEnd
%n = OpFunction %void None %6
%11 = OpLabel
%12 = OpArrayLength %uint %G 0
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,6 @@
@group(0) @binding(0) var<storage> G : array<i32>;
fn n() {
let p = &(G);
arrayLength(p);
}