[hlsl] transform: Zero init arrays with a loop

If the array size is greater than a threshold.
This is a work around for FXC stalling when initializing large arrays
with a single zero-init assignment.

Bug: tint:936
Fixed: tint:943
Fixed: tint:942
Change-Id: Ie93c8f373874b8d6d020d041fa48b38fb1352f71
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56775
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-07-05 17:18:16 +00:00 committed by Tint LUCI CQ
parent b0455217fa
commit b4ff73e250
11 changed files with 905 additions and 23 deletions

View File

@ -66,6 +66,10 @@ Output Hlsl::Run(const Program* in, const DataMap&) {
manager.Add<ExternalTextureTransform>();
manager.Add<PromoteInitializersToConstVar>();
manager.Add<PadArrayElements>();
ZeroInitWorkgroupMemory::Config zero_init_cfg;
zero_init_cfg.init_arrays_with_loop_size_threshold = 32; // 8 scalars
data.Add<ZeroInitWorkgroupMemory::Config>(zero_init_cfg);
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::BuiltinStyle::kStructMember);
auto out = manager.Run(in, data);

View File

@ -24,6 +24,7 @@
#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory::Config);
namespace tint {
namespace transform {
@ -32,14 +33,16 @@ namespace transform {
struct ZeroInitWorkgroupMemory::State {
/// The clone context
CloneContext& ctx;
/// The built statements
ast::StatementList& stmts;
/// The config
Config cfg;
/// Zero() generates the statements required to zero initialize the workgroup
/// storage expression of type `ty`.
/// @param ty the expression type
/// @param stmts the built statements
/// @param get_expr a function that builds the AST nodes for the expression
void Zero(const sem::Type* ty,
ast::StatementList& stmts,
const std::function<ast::Expression*()>& get_expr) {
if (CanZero(ty)) {
auto* var = get_expr();
@ -61,22 +64,33 @@ struct ZeroInitWorkgroupMemory::State {
if (auto* str = ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto name = ctx.Clone(member->Declaration()->symbol());
Zero(member->Type(),
Zero(member->Type(), stmts,
[&] { return ctx.dst->MemberAccessor(get_expr(), name); });
}
return;
}
if (auto* arr = ty->As<sem::Array>()) {
// TODO(bclayton): If array sizes become pipeline-overridable then this
// will need to emit code for a loop.
// See https://github.com/gpuweb/gpuweb/pull/1792
if (ShouldEmitForLoop(arr)) {
auto i = ctx.dst->Symbols().New("i");
auto* i_decl = ctx.dst->Decl(ctx.dst->Var(i, ctx.dst->ty.i32()));
auto* cond = ctx.dst->create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, ctx.dst->Expr(i),
ctx.dst->Expr(static_cast<int>(arr->Count())));
auto* inc = ctx.dst->Assign(i, ctx.dst->Add(i, 1));
ast::StatementList for_stmts;
Zero(arr->ElemType(), for_stmts,
[&] { return ctx.dst->IndexAccessor(get_expr(), i); });
auto* body = ctx.dst->Block(for_stmts);
stmts.emplace_back(ctx.dst->For(i_decl, cond, inc, body));
} else {
for (size_t i = 0; i < arr->Count(); i++) {
Zero(arr->ElemType(), [&] {
Zero(arr->ElemType(), stmts, [&] {
return ctx.dst->IndexAccessor(get_expr(),
static_cast<ProgramBuilder::u32>(i));
});
}
}
return;
}
@ -89,7 +103,7 @@ struct ZeroInitWorkgroupMemory::State {
/// CanZero() returns false, then the type needs to be initialized by
/// decomposing the initialization into multiple sub-initializations.
/// @param ty the type to inspect
static bool CanZero(const sem::Type* ty) {
bool CanZero(const sem::Type* ty) {
if (ty->Is<sem::Atomic>()) {
return false;
}
@ -101,21 +115,39 @@ struct ZeroInitWorkgroupMemory::State {
}
}
if (auto* arr = ty->As<sem::Array>()) {
if (!CanZero(arr->ElemType())) {
if (ShouldEmitForLoop(arr) || !CanZero(arr->ElemType())) {
return false;
}
}
return true;
}
/// @returns true if the array should be emitted as a for-loop instead of
/// using zero-initializer statements.
/// @param array the array
bool ShouldEmitForLoop(const sem::Array* array) {
// TODO(bclayton): If array sizes become pipeline-overridable then this
// we need to return true for these arrays.
// See https://github.com/gpuweb/gpuweb/pull/1792
return (cfg.init_arrays_with_loop_size_threshold != 0) &&
(array->SizeInBytes() >= cfg.init_arrays_with_loop_size_threshold);
}
};
ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) {
auto& sem = ctx.src->Sem();
Config cfg;
if (auto* c = inputs.Get<Config>()) {
cfg = *c;
}
for (auto* ast_func : ctx.src->AST().Functions()) {
if (!ast_func->IsEntryPoint()) {
continue;
@ -129,7 +161,7 @@ void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
if (var->StorageClass() != ast::StorageClass::kWorkgroup) {
continue;
}
State{ctx, stmts}.Zero(var->Type()->UnwrapRef(), [&] {
State{ctx, cfg}.Zero(var->Type()->UnwrapRef(), stmts, [&] {
auto var_name = ctx.Clone(var->Declaration()->symbol());
return ctx.dst->Expr(var_name);
});
@ -193,5 +225,11 @@ void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ctx.Clone();
}
ZeroInitWorkgroupMemory::Config::Config() = default;
ZeroInitWorkgroupMemory::Config::Config(const Config&) = default;
ZeroInitWorkgroupMemory::Config::~Config() = default;
ZeroInitWorkgroupMemory::Config& ZeroInitWorkgroupMemory::Config::operator=(
const Config&) = default;
} // namespace transform
} // namespace tint

View File

@ -26,6 +26,27 @@ namespace transform {
class ZeroInitWorkgroupMemory
: public Castable<ZeroInitWorkgroupMemory, Transform> {
public:
/// Configuration options for the transform
struct Config : public Castable<Config, Data> {
/// Constructor
Config();
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// Assignment operator
/// @returns this Config
Config& operator=(const Config&);
/// If greater than 0, then arrays of at least this size in bytes will be
/// zero initialized using a for loop. If 0, then the array is assigned a
/// zero initialized array with a single statement.
uint32_t init_arrays_with_loop_size_threshold = 0;
};
/// Constructor
ZeroInitWorkgroupMemory();

View File

@ -558,6 +558,56 @@ fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
EXPECT_EQ(expect, str(got));
}
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArray_InitWithLoop) {
auto* src = R"(
struct S {
a : array<i32, 3>; // size: 12, less than the loop threshold
b : array<i32, 4>; // size: 16, equal to the loop threshold
c : array<i32, 5>; // size: 20, greater than the loop threshold
};
var<workgroup> w : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
ignore(w); // Initialization should be inserted above this statement
}
)";
auto* expect = R"(
struct S {
a : array<i32, 3>;
b : array<i32, 4>;
c : array<i32, 5>;
};
var<workgroup> w : S;
[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
if ((local_invocation_index == 0u)) {
w.a = array<i32, 3>();
for(var i : i32; (i < 4); i = (i + 1)) {
w.b[i] = i32();
}
for(var i_1 : i32; (i_1 < 5); i_1 = (i_1 + 1)) {
w.c[i_1] = i32();
}
}
workgroupBarrier();
ignore(w);
}
)";
ZeroInitWorkgroupMemory::Config cfg;
cfg.init_arrays_with_loop_size_threshold = 16;
DataMap data;
data.Add<ZeroInitWorkgroupMemory::Config>(cfg);
auto got = Run<ZeroInitWorkgroupMemory>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -69,10 +69,16 @@ void main(tint_symbol_1 tint_symbol) {
const uint3 global_id = tint_symbol.global_id;
const uint local_invocation_index = tint_symbol.local_invocation_index;
if ((local_invocation_index == 0u)) {
const float tint_symbol_5[64][64] = (float[64][64])0;
mm_Asub = tint_symbol_5;
const float tint_symbol_6[64][64] = (float[64][64])0;
mm_Bsub = tint_symbol_6;
for(int i = 0; (i < 64); i = (i + 1)) {
for(int i_1 = 0; (i_1 < 64); i_1 = (i_1 + 1)) {
mm_Asub[i][i_1] = 0.0f;
}
}
for(int i_2 = 0; (i_2 < 64); i_2 = (i_2 + 1)) {
for(int i_3 = 0; (i_3 < 64); i_3 = (i_3 + 1)) {
mm_Bsub[i_2][i_3] = 0.0f;
}
}
}
GroupMemoryBarrierWithGroupSync();
const uint tileRow = (local_id.y * RowPerThread);

80
test/bug/tint/942.wgsl Normal file
View File

@ -0,0 +1,80 @@
[[block]] struct Params {
filterDim : u32;
blockDim : u32;
};
[[group(0), binding(0)]] var samp : sampler;
[[group(0), binding(1)]] var<uniform> params : Params;
[[group(1), binding(1)]] var inputTex : texture_2d<f32>;
[[group(1), binding(2)]] var outputTex : texture_storage_2d<rgba8unorm, write>;
[[block]] struct Flip {
value : u32;
};
[[group(1), binding(3)]] var<uniform> flip : Flip;
// This shader blurs the input texture in one direction, depending on whether
// |flip.value| is 0 or 1.
// It does so by running (256 / 4) threads per workgroup to load 256
// texels into 4 rows of shared memory. Each thread loads a
// 4 x 4 block of texels to take advantage of the texture sampling
// hardware.
// Then, each thread computes the blur result by averaging the adjacent texel values
// in shared memory.
// Because we're operating on a subset of the texture, we cannot compute all of the
// results since not all of the neighbors are available in shared memory.
// Specifically, with 256 x 256 tiles, we can only compute and write out
// square blocks of size 256 - (filterSize - 1). We compute the number of blocks
// needed in Javascript and dispatch that amount.
var<workgroup> tile : array<array<vec3<f32>, 256>, 4>;
[[stage(compute), workgroup_size(64, 1, 1)]]
fn main(
[[builtin(workgroup_id)]] WorkGroupID : vec3<u32>,
[[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>
) {
let filterOffset : u32 = (params.filterDim - 1u) / 2u;
let dims : vec2<i32> = textureDimensions(inputTex, 0);
let baseIndex = vec2<i32>(
WorkGroupID.xy * vec2<u32>(params.blockDim, 4u) +
LocalInvocationID.xy * vec2<u32>(4u, 1u)
) - vec2<i32>(i32(filterOffset), 0);
for (var r : u32 = 0u; r < 4u; r = r + 1u) {
for (var c : u32 = 0u; c < 4u; c = c + 1u) {
var loadIndex = baseIndex + vec2<i32>(i32(c), i32(r));
if (flip.value != 0u) {
loadIndex = loadIndex.yx;
}
tile[r][4u * LocalInvocationID.x + c] =
textureSampleLevel(inputTex, samp,
(vec2<f32>(loadIndex) + vec2<f32>(0.25, 0.25)) / vec2<f32>(dims), 0.0).rgb;
}
}
workgroupBarrier();
for (var r : u32 = 0u; r < 4u; r = r + 1u) {
for (var c : u32 = 0u; c < 4u; c = c + 1u) {
var writeIndex = baseIndex + vec2<i32>(i32(c), i32(r));
if (flip.value != 0u) {
writeIndex = writeIndex.yx;
}
let center : u32 = 4u * LocalInvocationID.x + c;
if (center >= filterOffset &&
center < 256u - filterOffset &&
all(writeIndex < dims)) {
var acc : vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
for (var f : u32 = 0u; f < params.filterDim; f = f + 1u) {
var i : u32 = center + f - filterOffset;
acc = acc + (1.0 / f32(params.filterDim)) * tile[r][i];
}
textureStore(outputTex, writeIndex, vec4<f32>(acc, 1.0));
}
}
}
}

View File

@ -0,0 +1,96 @@
SamplerState samp : register(s0, space0);
cbuffer cbuffer_params : register(b1, space0) {
uint4 params[1];
};
Texture2D<float4> inputTex : register(t1, space1);
RWTexture2D<float4> outputTex : register(u2, space1);
cbuffer cbuffer_flip : register(b3, space1) {
uint4 flip[1];
};
groupshared float3 tile[4][256];
struct tint_symbol_1 {
uint3 LocalInvocationID : SV_GroupThreadID;
uint local_invocation_index : SV_GroupIndex;
uint3 WorkGroupID : SV_GroupID;
};
[numthreads(64, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
const uint3 WorkGroupID = tint_symbol.WorkGroupID;
const uint3 LocalInvocationID = tint_symbol.LocalInvocationID;
const uint local_invocation_index = tint_symbol.local_invocation_index;
if ((local_invocation_index == 0u)) {
for(int i_1 = 0; (i_1 < 4); i_1 = (i_1 + 1)) {
for(int i_2 = 0; (i_2 < 256); i_2 = (i_2 + 1)) {
tile[i_1][i_2] = float3(0.0f, 0.0f, 0.0f);
}
}
}
GroupMemoryBarrierWithGroupSync();
const uint scalar_offset = (0u) / 4;
const uint filterOffset = ((params[scalar_offset / 4][scalar_offset % 4] - 1u) / 2u);
int3 tint_tmp;
inputTex.GetDimensions(0, tint_tmp.x, tint_tmp.y, tint_tmp.z);
const int2 dims = tint_tmp.xy;
const uint scalar_offset_1 = (4u) / 4;
const int2 baseIndex = (int2(((WorkGroupID.xy * uint2(params[scalar_offset_1 / 4][scalar_offset_1 % 4], 4u)) + (LocalInvocationID.xy * uint2(4u, 1u)))) - int2(int(filterOffset), 0));
{
uint r = 0u;
for(; !(!((r < 4u))); r = (r + 1u)) {
{
uint c = 0u;
for(; !(!((c < 4u))); c = (c + 1u)) {
int2 loadIndex = (baseIndex + int2(int(c), int(r)));
const uint scalar_offset_2 = (0u) / 4;
if ((flip[scalar_offset_2 / 4][scalar_offset_2 % 4] != 0u)) {
loadIndex = loadIndex.yx;
}
tile[r][((4u * LocalInvocationID.x) + c)] = inputTex.SampleLevel(samp, ((float2(loadIndex) + float2(0.25f, 0.25f)) / float2(dims)), 0.0f).rgb;
}
}
}
}
GroupMemoryBarrierWithGroupSync();
{
uint r = 0u;
for(; !(!((r < 4u))); r = (r + 1u)) {
{
uint c = 0u;
for(; !(!((c < 4u))); c = (c + 1u)) {
int2 writeIndex = (baseIndex + int2(int(c), int(r)));
const uint scalar_offset_3 = (0u) / 4;
if ((flip[scalar_offset_3 / 4][scalar_offset_3 % 4] != 0u)) {
writeIndex = writeIndex.yx;
}
const uint center = ((4u * LocalInvocationID.x) + c);
bool tint_tmp_2 = (center >= filterOffset);
if (tint_tmp_2) {
tint_tmp_2 = (center < (256u - filterOffset));
}
bool tint_tmp_1 = (tint_tmp_2);
if (tint_tmp_1) {
tint_tmp_1 = all((writeIndex < dims));
}
if ((tint_tmp_1)) {
float3 acc = float3(0.0f, 0.0f, 0.0f);
{
uint f = 0u;
while (true) {
const uint scalar_offset_4 = (0u) / 4;
if (!(!(!((f < params[scalar_offset_4 / 4][scalar_offset_4 % 4]))))) { break; }
uint i = ((center + f) - filterOffset);
const uint scalar_offset_5 = (0u) / 4;
acc = (acc + ((1.0f / float(params[scalar_offset_5 / 4][scalar_offset_5 % 4])) * tile[r][i]));
f = (f + 1u);
}
}
outputTex[writeIndex] = float4(acc, 1.0f);
}
}
}
}
}
return;
}

View File

@ -0,0 +1,102 @@
#include <metal_stdlib>
using namespace metal;
struct Params {
/* 0x0000 */ uint filterDim;
/* 0x0004 */ uint blockDim;
};
struct Flip {
/* 0x0000 */ uint value;
};
struct tint_array_wrapper_1 {
float3 arr[256];
};
struct tint_array_wrapper {
tint_array_wrapper_1 arr[4];
};
kernel void tint_symbol(texture2d<float, access::sample> tint_symbol_4 [[texture(1)]], sampler tint_symbol_5 [[sampler(0)]], texture2d<float, access::write> tint_symbol_6 [[texture(2)]], uint3 WorkGroupID [[threadgroup_position_in_grid]], uint3 LocalInvocationID [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], constant Params& params [[buffer(1)]], constant Flip& flip [[buffer(3)]]) {
threadgroup tint_array_wrapper tint_symbol_3;
if ((local_invocation_index == 0u)) {
tint_array_wrapper const tint_symbol_2 = {.arr={}};
tint_symbol_3 = tint_symbol_2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint const filterOffset = ((params.filterDim - 1u) / 2u);
int2 const dims = int2(tint_symbol_4.get_width(0), tint_symbol_4.get_height(0));
int2 const baseIndex = (int2(((WorkGroupID.xy * uint2(params.blockDim, 4u)) + (LocalInvocationID.xy * uint2(4u, 1u)))) - int2(int(filterOffset), 0));
{
uint r = 0u;
while (true) {
if (!((r < 4u))) {
break;
}
{
uint c = 0u;
while (true) {
if (!((c < 4u))) {
break;
}
int2 loadIndex = (baseIndex + int2(int(c), int(r)));
if ((flip.value != 0u)) {
loadIndex = loadIndex.yx;
}
tint_symbol_3.arr[r].arr[((4u * LocalInvocationID.x) + c)] = tint_symbol_4.sample(tint_symbol_5, ((float2(loadIndex) + float2(0.25f, 0.25f)) / float2(dims)), level(0.0f)).rgb;
{
c = (c + 1u);
}
}
}
{
r = (r + 1u);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
uint r = 0u;
while (true) {
if (!((r < 4u))) {
break;
}
{
uint c = 0u;
while (true) {
if (!((c < 4u))) {
break;
}
int2 writeIndex = (baseIndex + int2(int(c), int(r)));
if ((flip.value != 0u)) {
writeIndex = writeIndex.yx;
}
uint const center = ((4u * LocalInvocationID.x) + c);
if ((((center >= filterOffset) && (center < (256u - filterOffset))) && all((writeIndex < dims)))) {
float3 acc = float3(0.0f, 0.0f, 0.0f);
{
uint f = 0u;
while (true) {
if (!((f < params.filterDim))) {
break;
}
uint i = ((center + f) - filterOffset);
acc = (acc + ((1.0f / float(params.filterDim)) * tint_symbol_3.arr[r].arr[i]));
{
f = (f + 1u);
}
}
}
tint_symbol_6.write(float4(acc, 1.0f), uint2(writeIndex));
}
{
c = (c + 1u);
}
}
}
{
r = (r + 1u);
}
}
}
return;
}

View File

@ -0,0 +1,374 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 239
; Schema: 0
OpCapability Shader
OpCapability ImageQuery
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %tint_symbol_2 %tint_symbol %tint_symbol_1
OpExecutionMode %main LocalSize 64 1 1
OpName %samp "samp"
OpName %Params "Params"
OpMemberName %Params 0 "filterDim"
OpMemberName %Params 1 "blockDim"
OpName %params "params"
OpName %inputTex "inputTex"
OpName %outputTex "outputTex"
OpName %Flip "Flip"
OpMemberName %Flip 0 "value"
OpName %flip "flip"
OpName %tile "tile"
OpName %tint_symbol "tint_symbol"
OpName %tint_symbol_1 "tint_symbol_1"
OpName %tint_symbol_2 "tint_symbol_2"
OpName %main "main"
OpName %r "r"
OpName %c "c"
OpName %loadIndex "loadIndex"
OpName %r_0 "r"
OpName %c_0 "c"
OpName %writeIndex "writeIndex"
OpName %acc "acc"
OpName %f "f"
OpName %i "i"
OpDecorate %samp DescriptorSet 0
OpDecorate %samp Binding 0
OpDecorate %Params Block
OpMemberDecorate %Params 0 Offset 0
OpMemberDecorate %Params 1 Offset 4
OpDecorate %params NonWritable
OpDecorate %params DescriptorSet 0
OpDecorate %params Binding 1
OpDecorate %inputTex DescriptorSet 1
OpDecorate %inputTex Binding 1
OpDecorate %outputTex NonReadable
OpDecorate %outputTex DescriptorSet 1
OpDecorate %outputTex Binding 2
OpDecorate %Flip Block
OpMemberDecorate %Flip 0 Offset 0
OpDecorate %flip NonWritable
OpDecorate %flip DescriptorSet 1
OpDecorate %flip Binding 3
OpDecorate %_arr_v3float_uint_256 ArrayStride 16
OpDecorate %_arr__arr_v3float_uint_256_uint_4 ArrayStride 4096
OpDecorate %tint_symbol BuiltIn WorkgroupId
OpDecorate %tint_symbol_1 BuiltIn LocalInvocationId
OpDecorate %tint_symbol_2 BuiltIn LocalInvocationIndex
%3 = OpTypeSampler
%_ptr_UniformConstant_3 = OpTypePointer UniformConstant %3
%samp = OpVariable %_ptr_UniformConstant_3 UniformConstant
%uint = OpTypeInt 32 0
%Params = OpTypeStruct %uint %uint
%_ptr_Uniform_Params = OpTypePointer Uniform %Params
%params = OpVariable %_ptr_Uniform_Params Uniform
%float = OpTypeFloat 32
%10 = OpTypeImage %float 2D 0 0 0 1 Unknown
%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10
%inputTex = OpVariable %_ptr_UniformConstant_10 UniformConstant
%14 = OpTypeImage %float 2D 0 0 0 2 Rgba8
%_ptr_UniformConstant_14 = OpTypePointer UniformConstant %14
%outputTex = OpVariable %_ptr_UniformConstant_14 UniformConstant
%Flip = OpTypeStruct %uint
%_ptr_Uniform_Flip = OpTypePointer Uniform %Flip
%flip = OpVariable %_ptr_Uniform_Flip Uniform
%v3float = OpTypeVector %float 3
%uint_256 = OpConstant %uint 256
%_arr_v3float_uint_256 = OpTypeArray %v3float %uint_256
%uint_4 = OpConstant %uint 4
%_arr__arr_v3float_uint_256_uint_4 = OpTypeArray %_arr_v3float_uint_256 %uint_4
%_ptr_Workgroup__arr__arr_v3float_uint_256_uint_4 = OpTypePointer Workgroup %_arr__arr_v3float_uint_256_uint_4
%tile = OpVariable %_ptr_Workgroup__arr__arr_v3float_uint_256_uint_4 Workgroup
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%tint_symbol = OpVariable %_ptr_Input_v3uint Input
%tint_symbol_1 = OpVariable %_ptr_Input_v3uint Input
%_ptr_Input_uint = OpTypePointer Input %uint
%tint_symbol_2 = OpVariable %_ptr_Input_uint Input
%void = OpTypeVoid
%31 = OpTypeFunction %void
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%41 = OpConstantNull %_arr__arr_v3float_uint_256_uint_4
%uint_2 = OpConstant %uint 2
%uint_264 = OpConstant %uint 264
%_ptr_Uniform_uint = OpTypePointer Uniform %uint
%uint_1 = OpConstant %uint 1
%int = OpTypeInt 32 1
%v2int = OpTypeVector %int 2
%int_0 = OpConstant %int 0
%v2uint = OpTypeVector %uint 2
%66 = OpConstantComposite %v2uint %uint_4 %uint_1
%_ptr_Function_uint = OpTypePointer Function %uint
%74 = OpConstantNull %uint
%_ptr_Function_v2int = OpTypePointer Function %v2int
%102 = OpConstantNull %v2int
%_ptr_Workgroup_v3float = OpTypePointer Workgroup %v3float
%v4float = OpTypeVector %float 4
%122 = OpTypeSampledImage %10
%v2float = OpTypeVector %float 2
%float_0_25 = OpConstant %float 0.25
%128 = OpConstantComposite %v2float %float_0_25 %float_0_25
%float_0 = OpConstant %float 0
%v2bool = OpTypeVector %bool 2
%193 = OpConstantComposite %v3float %float_0 %float_0 %float_0
%_ptr_Function_v3float = OpTypePointer Function %v3float
%196 = OpConstantNull %v3float
%float_1 = OpConstant %float 1
%main = OpFunction %void None %31
%34 = OpLabel
%r = OpVariable %_ptr_Function_uint Function %74
%c = OpVariable %_ptr_Function_uint Function %74
%loadIndex = OpVariable %_ptr_Function_v2int Function %102
%r_0 = OpVariable %_ptr_Function_uint Function %74
%c_0 = OpVariable %_ptr_Function_uint Function %74
%writeIndex = OpVariable %_ptr_Function_v2int Function %102
%acc = OpVariable %_ptr_Function_v3float Function %196
%f = OpVariable %_ptr_Function_uint Function %74
%i = OpVariable %_ptr_Function_uint Function %74
%35 = OpLoad %uint %tint_symbol_2
%37 = OpIEqual %bool %35 %uint_0
OpSelectionMerge %39 None
OpBranchConditional %37 %40 %39
%40 = OpLabel
OpStore %tile %41
OpBranch %39
%39 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
%46 = OpAccessChain %_ptr_Uniform_uint %params %uint_0
%47 = OpLoad %uint %46
%49 = OpISub %uint %47 %uint_1
%50 = OpUDiv %uint %49 %uint_2
%54 = OpLoad %10 %inputTex
%51 = OpImageQuerySizeLod %v2int %54 %int_0
%58 = OpLoad %v3uint %tint_symbol
%59 = OpVectorShuffle %v2uint %58 %58 0 1
%60 = OpAccessChain %_ptr_Uniform_uint %params %uint_1
%61 = OpLoad %uint %60
%62 = OpCompositeConstruct %v2uint %61 %uint_4
%63 = OpIMul %v2uint %59 %62
%64 = OpLoad %v3uint %tint_symbol_1
%65 = OpVectorShuffle %v2uint %64 %64 0 1
%67 = OpIMul %v2uint %65 %66
%68 = OpIAdd %v2uint %63 %67
%56 = OpBitcast %v2int %68
%69 = OpBitcast %int %50
%70 = OpCompositeConstruct %v2int %69 %int_0
%71 = OpISub %v2int %56 %70
OpStore %r %uint_0
OpBranch %75
%75 = OpLabel
OpLoopMerge %76 %77 None
OpBranch %78
%78 = OpLabel
%80 = OpLoad %uint %r
%81 = OpULessThan %bool %80 %uint_4
%79 = OpLogicalNot %bool %81
OpSelectionMerge %82 None
OpBranchConditional %79 %83 %82
%83 = OpLabel
OpBranch %76
%82 = OpLabel
OpStore %c %uint_0
OpBranch %85
%85 = OpLabel
OpLoopMerge %86 %87 None
OpBranch %88
%88 = OpLabel
%90 = OpLoad %uint %c
%91 = OpULessThan %bool %90 %uint_4
%89 = OpLogicalNot %bool %91
OpSelectionMerge %92 None
OpBranchConditional %89 %93 %92
%93 = OpLabel
OpBranch %86
%92 = OpLabel
%95 = OpLoad %uint %c
%94 = OpBitcast %int %95
%97 = OpLoad %uint %r
%96 = OpBitcast %int %97
%98 = OpCompositeConstruct %v2int %94 %96
%99 = OpIAdd %v2int %71 %98
OpStore %loadIndex %99
%103 = OpAccessChain %_ptr_Uniform_uint %flip %uint_0
%104 = OpLoad %uint %103
%105 = OpINotEqual %bool %104 %uint_0
OpSelectionMerge %106 None
OpBranchConditional %105 %107 %106
%107 = OpLabel
%108 = OpLoad %v2int %loadIndex
%109 = OpVectorShuffle %v2int %108 %108 1 0
OpStore %loadIndex %109
OpBranch %106
%106 = OpLabel
%110 = OpLoad %uint %r
%111 = OpAccessChain %_ptr_Input_uint %tint_symbol_1 %uint_0
%112 = OpLoad %uint %111
%113 = OpIMul %uint %uint_4 %112
%114 = OpLoad %uint %c
%115 = OpIAdd %uint %113 %114
%117 = OpAccessChain %_ptr_Workgroup_v3float %tile %110 %115
%120 = OpLoad %3 %samp
%121 = OpLoad %10 %inputTex
%123 = OpSampledImage %122 %121 %120
%126 = OpLoad %v2int %loadIndex
%124 = OpConvertSToF %v2float %126
%129 = OpFAdd %v2float %124 %128
%130 = OpConvertSToF %v2float %51
%131 = OpFDiv %v2float %129 %130
%118 = OpImageSampleExplicitLod %v4float %123 %131 Lod %float_0
%133 = OpVectorShuffle %v3float %118 %118 0 1 2
OpStore %117 %133
OpBranch %87
%87 = OpLabel
%134 = OpLoad %uint %c
%135 = OpIAdd %uint %134 %uint_1
OpStore %c %135
OpBranch %85
%86 = OpLabel
OpBranch %77
%77 = OpLabel
%136 = OpLoad %uint %r
%137 = OpIAdd %uint %136 %uint_1
OpStore %r %137
OpBranch %75
%76 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
OpStore %r_0 %uint_0
OpBranch %140
%140 = OpLabel
OpLoopMerge %141 %142 None
OpBranch %143
%143 = OpLabel
%145 = OpLoad %uint %r_0
%146 = OpULessThan %bool %145 %uint_4
%144 = OpLogicalNot %bool %146
OpSelectionMerge %147 None
OpBranchConditional %144 %148 %147
%148 = OpLabel
OpBranch %141
%147 = OpLabel
OpStore %c_0 %uint_0
OpBranch %150
%150 = OpLabel
OpLoopMerge %151 %152 None
OpBranch %153
%153 = OpLabel
%155 = OpLoad %uint %c_0
%156 = OpULessThan %bool %155 %uint_4
%154 = OpLogicalNot %bool %156
OpSelectionMerge %157 None
OpBranchConditional %154 %158 %157
%158 = OpLabel
OpBranch %151
%157 = OpLabel
%160 = OpLoad %uint %c_0
%159 = OpBitcast %int %160
%162 = OpLoad %uint %r_0
%161 = OpBitcast %int %162
%163 = OpCompositeConstruct %v2int %159 %161
%164 = OpIAdd %v2int %71 %163
OpStore %writeIndex %164
%166 = OpAccessChain %_ptr_Uniform_uint %flip %uint_0
%167 = OpLoad %uint %166
%168 = OpINotEqual %bool %167 %uint_0
OpSelectionMerge %169 None
OpBranchConditional %168 %170 %169
%170 = OpLabel
%171 = OpLoad %v2int %writeIndex
%172 = OpVectorShuffle %v2int %171 %171 1 0
OpStore %writeIndex %172
OpBranch %169
%169 = OpLabel
%173 = OpAccessChain %_ptr_Input_uint %tint_symbol_1 %uint_0
%174 = OpLoad %uint %173
%175 = OpIMul %uint %uint_4 %174
%176 = OpLoad %uint %c_0
%177 = OpIAdd %uint %175 %176
%178 = OpUGreaterThanEqual %bool %177 %50
OpSelectionMerge %179 None
OpBranchConditional %178 %180 %179
%180 = OpLabel
%181 = OpISub %uint %uint_256 %50
%182 = OpULessThan %bool %177 %181
OpBranch %179
%179 = OpLabel
%183 = OpPhi %bool %178 %169 %182 %180
OpSelectionMerge %184 None
OpBranchConditional %183 %185 %184
%185 = OpLabel
%187 = OpLoad %v2int %writeIndex
%188 = OpSLessThan %v2bool %187 %51
%186 = OpAll %bool %188
OpBranch %184
%184 = OpLabel
%190 = OpPhi %bool %183 %179 %186 %185
OpSelectionMerge %191 None
OpBranchConditional %190 %192 %191
%192 = OpLabel
OpStore %acc %193
OpStore %f %uint_0
OpBranch %198
%198 = OpLabel
OpLoopMerge %199 %200 None
OpBranch %201
%201 = OpLabel
%203 = OpLoad %uint %f
%204 = OpAccessChain %_ptr_Uniform_uint %params %uint_0
%205 = OpLoad %uint %204
%206 = OpULessThan %bool %203 %205
%202 = OpLogicalNot %bool %206
OpSelectionMerge %207 None
OpBranchConditional %202 %208 %207
%208 = OpLabel
OpBranch %199
%207 = OpLabel
%209 = OpLoad %uint %f
%210 = OpIAdd %uint %177 %209
%211 = OpISub %uint %210 %50
OpStore %i %211
%213 = OpLoad %v3float %acc
%216 = OpAccessChain %_ptr_Uniform_uint %params %uint_0
%217 = OpLoad %uint %216
%215 = OpConvertUToF %float %217
%218 = OpFDiv %float %float_1 %215
%219 = OpLoad %uint %r_0
%220 = OpLoad %uint %i
%221 = OpAccessChain %_ptr_Workgroup_v3float %tile %219 %220
%222 = OpLoad %v3float %221
%223 = OpVectorTimesScalar %v3float %222 %218
%224 = OpFAdd %v3float %213 %223
OpStore %acc %224
OpBranch %200
%200 = OpLabel
%225 = OpLoad %uint %f
%226 = OpIAdd %uint %225 %uint_1
OpStore %f %226
OpBranch %198
%199 = OpLabel
%228 = OpLoad %14 %outputTex
%229 = OpLoad %v2int %writeIndex
%230 = OpLoad %v3float %acc
%231 = OpCompositeExtract %float %230 0
%232 = OpCompositeExtract %float %230 1
%233 = OpCompositeExtract %float %230 2
%234 = OpCompositeConstruct %v4float %231 %232 %233 %float_1
OpImageWrite %228 %229 %234
OpBranch %191
%191 = OpLabel
OpBranch %152
%152 = OpLabel
%235 = OpLoad %uint %c_0
%236 = OpIAdd %uint %235 %uint_1
OpStore %c_0 %236
OpBranch %150
%151 = OpLabel
OpBranch %142
%142 = OpLabel
%237 = OpLoad %uint %r_0
%238 = OpIAdd %uint %237 %uint_1
OpStore %r_0 %238
OpBranch %140
%141 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,106 @@
[[block]]
struct Params {
filterDim : u32;
blockDim : u32;
};
[[group(0), binding(0)]] var samp : sampler;
[[group(0), binding(1)]] var<uniform> params : Params;
[[group(1), binding(1)]] var inputTex : texture_2d<f32>;
[[group(1), binding(2)]] var outputTex : texture_storage_2d<rgba8unorm, write>;
[[block]]
struct Flip {
value : u32;
};
[[group(1), binding(3)]] var<uniform> flip : Flip;
var<workgroup> tile : array<array<vec3<f32>, 256>, 4>;
[[stage(compute), workgroup_size(64, 1, 1)]]
fn main([[builtin(workgroup_id)]] WorkGroupID : vec3<u32>, [[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
let filterOffset : u32 = ((params.filterDim - 1u) / 2u);
let dims : vec2<i32> = textureDimensions(inputTex, 0);
let baseIndex = (vec2<i32>(((WorkGroupID.xy * vec2<u32>(params.blockDim, 4u)) + (LocalInvocationID.xy * vec2<u32>(4u, 1u)))) - vec2<i32>(i32(filterOffset), 0));
{
var r : u32 = 0u;
loop {
if (!((r < 4u))) {
break;
}
{
var c : u32 = 0u;
loop {
if (!((c < 4u))) {
break;
}
var loadIndex = (baseIndex + vec2<i32>(i32(c), i32(r)));
if ((flip.value != 0u)) {
loadIndex = loadIndex.yx;
}
tile[r][((4u * LocalInvocationID.x) + c)] = textureSampleLevel(inputTex, samp, ((vec2<f32>(loadIndex) + vec2<f32>(0.25, 0.25)) / vec2<f32>(dims)), 0.0).rgb;
continuing {
c = (c + 1u);
}
}
}
continuing {
r = (r + 1u);
}
}
}
workgroupBarrier();
{
var r : u32 = 0u;
loop {
if (!((r < 4u))) {
break;
}
{
var c : u32 = 0u;
loop {
if (!((c < 4u))) {
break;
}
var writeIndex = (baseIndex + vec2<i32>(i32(c), i32(r)));
if ((flip.value != 0u)) {
writeIndex = writeIndex.yx;
}
let center : u32 = ((4u * LocalInvocationID.x) + c);
if ((((center >= filterOffset) && (center < (256u - filterOffset))) && all((writeIndex < dims)))) {
var acc : vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
{
var f : u32 = 0u;
loop {
if (!((f < params.filterDim))) {
break;
}
var i : u32 = ((center + f) - filterOffset);
acc = (acc + ((1.0 / f32(params.filterDim)) * tile[r][i]));
continuing {
f = (f + 1u);
}
}
}
textureStore(outputTex, writeIndex, vec4<f32>(acc, 1.0));
}
continuing {
c = (c + 1u);
}
}
}
continuing {
r = (r + 1u);
}
}
}
}

View File

@ -336,10 +336,15 @@ void main(tint_symbol_1 tint_symbol) {
const uint3 gl_GlobalInvocationID_param = tint_symbol.gl_GlobalInvocationID_param;
const uint local_invocation_index = tint_symbol.local_invocation_index;
if ((local_invocation_index == 0u)) {
const float tint_symbol_6[64][64] = (float[64][64])0;
mm_Asub = tint_symbol_6;
const float tint_symbol_7[64][1] = (float[64][1])0;
mm_Bsub = tint_symbol_7;
for(int i = 0; (i < 64); i = (i + 1)) {
for(int i_1 = 0; (i_1 < 64); i_1 = (i_1 + 1)) {
mm_Asub[i][i_1] = 0.0f;
}
}
for(int i_2 = 0; (i_2 < 64); i_2 = (i_2 + 1)) {
const float tint_symbol_6[1] = (float[1])0;
mm_Bsub[i_2] = tint_symbol_6;
}
}
GroupMemoryBarrierWithGroupSync();
gl_LocalInvocationID = gl_LocalInvocationID_param;