tint: Fix transform errors when calling arrayLength() as a statement
Bug: chromium:1360925 Change-Id: If60fa4bb1cf4981c10dd15f8814c0aed70c0066e Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101780 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
f313c48030
commit
4b70776aed
|
@ -21,6 +21,7 @@
|
|||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
|
||||
|
@ -33,65 +34,79 @@ namespace tint::transform {
|
|||
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
|
||||
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
|
||||
|
||||
/// Iterate over all arrayLength() builtins that operate on
|
||||
/// storage buffer variables.
|
||||
/// @param ctx the CloneContext.
|
||||
/// @param functor of type void(const ast::CallExpression*, const
|
||||
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
|
||||
/// ast::CallExpression of the arrayLength call expression node, a
|
||||
/// sem::VariableUser of the used storage buffer variable, and the
|
||||
/// sem::GlobalVariable for the storage buffer.
|
||||
template <typename F>
|
||||
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
|
||||
auto& sem = ctx.src->Sem();
|
||||
/// The PIMPL state for this transform
|
||||
struct ArrayLengthFromUniform::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
||||
// Find all calls to the arrayLength() builtin.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* call_expr = node->As<ast::CallExpression>();
|
||||
if (!call_expr) {
|
||||
continue;
|
||||
}
|
||||
/// Iterate over all arrayLength() builtins that operate on
|
||||
/// storage buffer variables.
|
||||
/// @param functor of type void(const ast::CallExpression*, const
|
||||
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
|
||||
/// ast::CallExpression of the arrayLength call expression node, a
|
||||
/// sem::VariableUser of the used storage buffer variable, and the
|
||||
/// sem::GlobalVariable for the storage buffer.
|
||||
template <typename F>
|
||||
void IterateArrayLengthOnStorageVar(F&& functor) {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* builtin = call->Target()->As<sem::Builtin>();
|
||||
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
|
||||
continue;
|
||||
}
|
||||
// Find all calls to the arrayLength() builtin.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* call_expr = node->As<ast::CallExpression>();
|
||||
if (!call_expr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the storage buffer that contains the runtime array.
|
||||
// Since we require SimplifyPointers, we can assume that the arrayLength()
|
||||
// call has one of two forms:
|
||||
// arrayLength(&struct_var.array_member)
|
||||
// arrayLength(&array_var)
|
||||
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
|
||||
if (!param || param->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
auto* storage_buffer_expr = param->expr;
|
||||
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
|
||||
storage_buffer_expr = accessor->structure;
|
||||
}
|
||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* builtin = call->Target()->As<sem::Builtin>();
|
||||
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the index to use for the buffer size array.
|
||||
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
|
||||
if (!var) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "storage buffer is not a global variable";
|
||||
break;
|
||||
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
|
||||
if (call_stmt->expr == call_expr) {
|
||||
// arrayLength() is used as a statement.
|
||||
// The argument expression must be side-effect free, so just drop the statement.
|
||||
RemoveStatement(ctx, call_stmt);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the storage buffer that contains the runtime array.
|
||||
// Since we require SimplifyPointers, we can assume that the arrayLength()
|
||||
// call has one of two forms:
|
||||
// arrayLength(&struct_var.array_member)
|
||||
// arrayLength(&array_var)
|
||||
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
|
||||
if (!param || param->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
auto* storage_buffer_expr = param->expr;
|
||||
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
|
||||
storage_buffer_expr = accessor->structure;
|
||||
}
|
||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
|
||||
// Get the index to use for the buffer size array.
|
||||
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
|
||||
if (!var) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "storage buffer is not a global variable";
|
||||
break;
|
||||
}
|
||||
functor(call_expr, storage_buffer_sem, var);
|
||||
}
|
||||
functor(call_expr, storage_buffer_sem, var);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
|
@ -119,17 +134,17 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataM
|
|||
// Determine the size of the buffer size array.
|
||||
uint32_t max_buffer_size_index = 0;
|
||||
|
||||
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
if (idx_itr->second > max_buffer_size_index) {
|
||||
max_buffer_size_index = idx_itr->second;
|
||||
}
|
||||
});
|
||||
State{ctx}.IterateArrayLengthOnStorageVar(
|
||||
[&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
if (idx_itr->second > max_buffer_size_index) {
|
||||
max_buffer_size_index = idx_itr->second;
|
||||
}
|
||||
});
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// size of each storage buffer in the module.
|
||||
|
@ -156,9 +171,9 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataM
|
|||
|
||||
std::unordered_set<uint32_t> used_size_indices;
|
||||
|
||||
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr,
|
||||
const sem::VariableUser* storage_buffer_sem,
|
||||
const sem::GlobalVariable* var) {
|
||||
State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
|
||||
const sem::VariableUser* storage_buffer_sem,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
|
|
|
@ -113,6 +113,10 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra
|
|||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
/// The PIMPL state for this transform
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -496,5 +496,43 @@ struct SB {
|
|||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, CallStatement) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
arr : array<i32>,
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
arrayLength(&a.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct SB {
|
||||
arr : array<i32>,
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -130,6 +130,16 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
// We're dealing with an arrayLength() call
|
||||
|
||||
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
|
||||
if (call_stmt->expr == call_expr) {
|
||||
// arrayLength() is used as a statement.
|
||||
// The argument expression must be side-effect free, so just drop the
|
||||
// statement.
|
||||
RemoveStatement(ctx, call_stmt);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// A runtime-sized array can only appear as the store type of a variable, or the
|
||||
// last element of a structure (which cannot itself be nested). Given that we
|
||||
// require SimplifyPointers, we can assume that the arrayLength() call has one
|
||||
|
|
|
@ -547,5 +547,37 @@ struct SB2 {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, CallStatement) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
arr : array<i32>,
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
arrayLength(&a.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct SB {
|
||||
arr : array<i32>,
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage> G : array<i32>;
|
||||
|
||||
fn n() {
|
||||
let p = &G;
|
||||
arrayLength(p);
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
ByteAddressBuffer G : register(t0, space0);
|
||||
|
||||
void n() {
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
ByteAddressBuffer G : register(t0, space0);
|
||||
|
||||
void n() {
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
#version 310 es
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
layout(binding = 0, std430) buffer G_block_ssbo {
|
||||
int inner[];
|
||||
} G;
|
||||
|
||||
void n() {
|
||||
uint(G.inner.length());
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void n() {
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 14
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
|
||||
OpExecutionMode %unused_entry_point LocalSize 1 1 1
|
||||
OpName %G_block "G_block"
|
||||
OpMemberName %G_block 0 "inner"
|
||||
OpName %G "G"
|
||||
OpName %unused_entry_point "unused_entry_point"
|
||||
OpName %n "n"
|
||||
OpDecorate %G_block Block
|
||||
OpMemberDecorate %G_block 0 Offset 0
|
||||
OpDecorate %_runtimearr_int ArrayStride 4
|
||||
OpDecorate %G NonWritable
|
||||
OpDecorate %G DescriptorSet 0
|
||||
OpDecorate %G Binding 0
|
||||
%int = OpTypeInt 32 1
|
||||
%_runtimearr_int = OpTypeRuntimeArray %int
|
||||
%G_block = OpTypeStruct %_runtimearr_int
|
||||
%_ptr_StorageBuffer_G_block = OpTypePointer StorageBuffer %G_block
|
||||
%G = OpVariable %_ptr_StorageBuffer_G_block StorageBuffer
|
||||
%void = OpTypeVoid
|
||||
%6 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%unused_entry_point = OpFunction %void None %6
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%n = OpFunction %void None %6
|
||||
%11 = OpLabel
|
||||
%12 = OpArrayLength %uint %G 0
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,6 @@
|
|||
@group(0) @binding(0) var<storage> G : array<i32>;
|
||||
|
||||
fn n() {
|
||||
let p = &(G);
|
||||
arrayLength(p);
|
||||
}
|
Loading…
Reference in New Issue