// Uniform parameters for the blur passes.
struct Params {
  // Number of taps in the box filter (width of the averaging window).
  filterDim : u32,
  // Output columns written per workgroup; per the notes below, the host sets
  // this to 256 - (filterDim - 1) — TODO confirm against the host code.
  blockDim : u32,
}

@group(0) @binding(0) var samp : sampler;
// NOTE(review): address space reconstructed as `uniform`; confirm it matches
// the buffer binding type declared in the host-side bind group layout.
@group(0) @binding(1) var<uniform> params : Params;
@group(1) @binding(1) var inputTex : texture_2d<f32>;
// NOTE(review): the storage format must match the texture the host creates;
// rgba8unorm is assumed here — confirm against the texture creation code.
@group(1) @binding(2) var outputTex : texture_storage_2d<rgba8unorm, write>;

struct Flip {
  // 0 selects the unflipped pass; any non-zero value transposes the
  // load/store coordinates so the same code blurs along the other axis.
  value : u32,
}
@group(1) @binding(3) var<uniform> flip : Flip;

// This shader blurs the input texture along one axis: along x when
// |flip.value| is 0, and along y when it is non-zero (the x/y coordinates of
// every load and store are swapped).
// Each workgroup runs (256 / 4) = 64 invocations that cooperatively load a
// strip of 4 rows x 256 texels into workgroup (shared) memory. Each invocation
// loads a 4 x 4 block of texels using the texture sampling hardware.
// Then each invocation computes the blur result by averaging adjacent texel
// values in shared memory.
// Because a workgroup only caches a subset of the texture, it cannot compute
// every result in its strip: texels near either end of the strip lack some of
// their neighbours. With 256-texel-wide strips, each workgroup can therefore
// only write a block of 256 - (filterSize - 1) columns by 4 rows. The number
// of blocks needed is computed in JavaScript, which dispatches that many
// workgroups.
// Width, in texels, of the strip each workgroup caches. It must equal
// 4 * the workgroup's x size, because each invocation loads 4 adjacent columns.
const tileDim : u32 = 256u;

// Workgroup-shared cache: 4 rows of tileDim RGB texels.
var<workgroup> tile : array<array<vec3<f32>, tileDim>, 4>;

// Box-blurs `inputTex` along one axis (selected by `flip.value`) into
// `outputTex`. Each workgroup writes one block of up to blockDim x 4 texels.
@compute @workgroup_size(tileDim / 4u, 1u, 1u)
fn main(
  @builtin(workgroup_id) WorkGroupID : vec3<u32>,
  @builtin(local_invocation_id) LocalInvocationID : vec3<u32>
) {
  // Taps to the left of the output texel. NOTE: assumes params.filterDim >= 1.
  // With 0, the u32 subtraction wraps and the guard below rejects every texel,
  // so nothing is written.
  let filterOffset : u32 = (params.filterDim - 1u) / 2u;
  // Taps to the right of the output texel. This equals filterOffset for odd
  // filterDim and is one more for even filterDim.
  let filterRight : u32 = params.filterDim - 1u - filterOffset;
  // Every tap has the same weight (box filter), so compute it once.
  let weight : f32 = 1.0 / f32(params.filterDim);
  // Explicit conversion keeps the signed comparison with writeIndex valid
  // regardless of the integer type textureDimensions returns.
  let dims : vec2<i32> = vec2<i32>(textureDimensions(inputTex, 0));

  // First texel of this invocation's 4x4 block in the unflipped frame. It is
  // shifted left by filterOffset so the strip also covers the left apron.
  let baseIndex = vec2<i32>(
    WorkGroupID.xy * vec2<u32>(params.blockDim, 4u) +
    LocalInvocationID.xy * vec2<u32>(4u, 1u)
  ) - vec2<i32>(i32(filterOffset), 0);

  // Phase 1: each invocation samples its 4x4 block into the shared tile.
  for (var r : u32 = 0u; r < 4u; r++) {
    for (var c : u32 = 0u; c < 4u; c++) {
      var loadIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        // Flipped pass: swap x and y so the same code walks the other axis.
        loadIndex = loadIndex.yx;
      }

      // Out-of-range coordinates (e.g. the negative apron) are handled by the
      // sampler's address mode.
      // NOTE(review): samples at texel + 0.25 rather than the texel centre
      // (+0.5). With a linear-filtering sampler this blends in a neighbour.
      // This is kept as in the original; confirm the intent.
      tile[r][4u * LocalInvocationID.x + c] = textureSampleLevel(
        inputTex,
        samp,
        (vec2<f32>(loadIndex) + vec2<f32>(0.25, 0.25)) / vec2<f32>(dims),
        0.0
      ).rgb;
    }
  }

  // Every tile slot must be written before any invocation reads neighbours.
  workgroupBarrier();

  // Phase 2: average filterDim consecutive tile entries around each centre.
  for (var r : u32 = 0u; r < 4u; r++) {
    for (var c : u32 = 0u; c < 4u; c++) {
      var writeIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        writeIndex = writeIndex.yx;
      }

      let center : u32 = 4u * LocalInvocationID.x + c;
      // Only centres whose whole window [center - filterOffset,
      // center + filterRight] lies inside the tile can be computed.
      // - The right bound uses filterRight (not filterOffset), so an even
      //   filterDim never reads tile[r][tileDim]. It also never overlaps the
      //   next workgroup's block.
      // - The bound is written as a subtraction (center < tileDim) to avoid
      //   u32 overflow.
      // - center >= filterOffset also guarantees writeIndex is non-negative.
      if (center >= filterOffset &&
          filterRight < tileDim - center &&
          all(writeIndex < dims)) {
        var acc = vec3<f32>(0.0, 0.0, 0.0);
        for (var f : u32 = 0u; f < params.filterDim; f++) {
          acc = acc + weight * tile[r][center + f - filterOffset];
        }
        textureStore(outputTex, writeIndex, vec4<f32>(acc, 1.0));
      }
    }
  }
}