// Blur parameters shared by every invocation.
//   filterDim: number of taps in the 1-D box filter.
//   blockDim:  per-workgroup stride (in texels) along the blur axis; main()
//              multiplies WorkGroupID.x by it to place each workgroup's strip.
[[block]] struct Params {
  filterDim : u32;
  blockDim : u32;
};

[[group(0), binding(0)]] var samp : sampler;
// NOTE(review): uniform address space assumed; must match the host bind group layout.
[[group(0), binding(1)]] var<uniform> params : Params;
[[group(1), binding(1)]] var inputTex : texture_2d<f32>;
// NOTE(review): storage format assumed to be rgba8unorm; it must match the
// format of the texture bound on the host side — confirm.
[[group(1), binding(2)]] var outputTex : texture_storage_2d<rgba8unorm, write>;

// Blur direction selector: 0 blurs along x; any non-zero value swaps x and y so
// the same code blurs along y.
[[block]] struct Flip {
  value : u32;
};
// NOTE(review): uniform address space assumed; must match the host bind group layout.
[[group(1), binding(3)]] var<uniform> flip : Flip;

// This shader blurs the input texture in one direction, depending on whether
// |flip.value| is 0 or non-zero.
// Each workgroup runs 64 (= 256 / 4) invocations that together load a
// 256 x 4 strip of texels into 4 rows of shared memory. Each invocation loads a
// 4 x 4 block of texels (4 consecutive columns in each of the 4 rows).
// Then, each invocation computes the blur result for its own texels by
// averaging the adjacent texel values in shared memory.
// Because we're operating on a subset of the texture, we cannot compute all of
// the results, since not every neighbor is available in shared memory.
// Specifically, with 256 x 4 tiles, each workgroup can only compute and write
// out a block 256 - (filterDim - 1) texels wide (exact for odd filterDim) and
// 4 texels tall; consecutive workgroups are offset by params.blockDim texels.
// We compute the number of blocks needed in JavaScript and dispatch that amount.
// Shared tile: 4 rows of 256 RGB texels, filled by the whole workgroup before
// any invocation starts filtering.
var<workgroup> tile : array<array<vec3<f32>, 256>, 4>;

// One separable-blur pass along x (flip.value == 0) or y (flip.value != 0).
// Phase 1: each of the 64 invocations samples a 4 x 4 patch of the input
//          (columns 4*lid.x .. 4*lid.x+3 of all 4 rows) into `tile`.
// Phase 2: after a workgroup barrier, each invocation box-filters its 16 texels
//          with params.filterDim equally weighted taps. It writes only the
//          texels whose whole filter footprint lies inside the tile and whose
//          position lies inside the input texture's dimensions.
[[stage(compute), workgroup_size(64, 1, 1)]]
fn main(
  [[builtin(workgroup_id)]] WorkGroupID : vec3<u32>,
  [[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>
) {
  // Taps to the left of the centre tap.
  let filterOffset : u32 = (params.filterDim - 1u) / 2u;
  // Taps to the right of the centre tap. This equals filterOffset for an odd
  // filterDim and is one more for an even filterDim.
  let filterReach : u32 = params.filterDim - 1u - filterOffset;
  // Uniform box-filter weight (loop-invariant, so computed once).
  let weight : f32 = 1.0 / f32(params.filterDim);
  let dims : vec2<i32> = textureDimensions(inputTex, 0);

  // Top-left texel (unflipped coordinates) of this invocation's 4 x 4 patch.
  // The patch is shifted left by filterOffset so the tile also contains the
  // left-hand neighbours of the first texel this workgroup writes.
  let baseIndex = vec2<i32>(
    WorkGroupID.xy * vec2<u32>(params.blockDim, 4u) +
    LocalInvocationID.xy * vec2<u32>(4u, 1u)
  ) - vec2<i32>(i32(filterOffset), 0);

  // Phase 1: load the tile.
  for (var r : u32 = 0u; r < 4u; r = r + 1u) {
    for (var c : u32 = 0u; c < 4u; c = c + 1u) {
      var loadIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        loadIndex = loadIndex.yx;
      }
      // NOTE(review): samples at (texel + 0.25) in normalized coordinates. Whether
      // this returns a single texel or a blend depends on `samp`'s filter mode,
      // which is configured on the host — confirm.
      tile[r][4u * LocalInvocationID.x + c] = textureSampleLevel(
        inputTex,
        samp,
        (vec2<f32>(loadIndex) + vec2<f32>(0.25, 0.25)) / vec2<f32>(dims),
        0.0
      ).rgb;
    }
  }

  // Every invocation must finish writing `tile` before any invocation reads it.
  workgroupBarrier();

  // Phase 2: filter and store.
  for (var r : u32 = 0u; r < 4u; r = r + 1u) {
    for (var c : u32 = 0u; c < 4u; c = c + 1u) {
      var writeIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        writeIndex = writeIndex.yx;
      }

      let center : u32 = 4u * LocalInvocationID.x + c;
      // Taps read tile[r][center - filterOffset ..= center + filterReach].
      // Checking the real right reach (not filterOffset) prevents the read of
      // tile[r][256] that an even filterDim would otherwise cause. It also makes
      // each workgroup write exactly 256 - (filterDim - 1) texels per row.
      if (center >= filterOffset &&
          center + filterReach < 256u &&
          all(writeIndex < dims)) {
        var acc : vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
        for (var f : u32 = 0u; f < params.filterDim; f = f + 1u) {
          let i : u32 = center + f - filterOffset;
          acc = acc + weight * tile[r][i];
        }
        textureStore(outputTex, writeIndex, vec4<f32>(acc, 1.0));
      }
    }
  }
}