// Tiled matrix multiplication: resultMatrix = firstMatrix * secondMatrix.
//
// All three matrices are flat f32 arrays indexed as (row * rowLength + col).
// A 16x16 workgroup cooperates on one 64x64 output tile; each invocation
// accumulates a RowPerThread x ColPerThread (4x4) patch of that tile.
//
// NOTE(review): the address spaces / access modes on the bindings below and
// the array template arguments were reconstructed from how the code uses
// them -- confirm against the pipeline's bind group layout.

// Matrix dimensions: A is dimAOuter x dimInner, B is dimInner x dimBOuter,
// and the result is dimAOuter x dimBOuter.
[[block]] struct Uniforms {
  dimAOuter : u32;
  dimInner : u32;
  dimBOuter : u32;
};

// Flat storage for one matrix.
[[block]] struct Matrix {
  numbers : array<f32>;
};

[[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix;
[[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix;
[[group(0), binding(2)]] var<storage, read_write> resultMatrix : Matrix;
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;

// Returns A[row][col], or 0.0 when (row, col) lies outside the matrix so
// that partially filled edge tiles contribute nothing to the sum.
fn mm_readA(row : u32, col : u32) -> f32 {
  if ((row < uniforms.dimAOuter) && (col < uniforms.dimInner)) {
    return firstMatrix.numbers[(row * uniforms.dimInner) + col];
  }
  return 0.0;
}

// Returns B[row][col], or 0.0 when (row, col) lies outside the matrix.
fn mm_readB(row : u32, col : u32) -> f32 {
  if ((row < uniforms.dimInner) && (col < uniforms.dimBOuter)) {
    return secondMatrix.numbers[(row * uniforms.dimBOuter) + col];
  }
  return 0.0;
}

// Stores `value` at result[row][col]; out-of-range coordinates are ignored.
fn mm_write(row : u32, col : u32, value : f32) {
  if ((row < uniforms.dimAOuter) && (col < uniforms.dimBOuter)) {
    let index : u32 = col + (row * uniforms.dimBOuter);
    resultMatrix.numbers[index] = value;
  }
}

// Output patch computed by a single invocation.
let RowPerThread : u32 = 4u;
let ColPerThread : u32 = 4u;

// Tile extents shared by the whole workgroup.
let TileAOuter : u32 = 64u;
let TileBOuter : u32 = 64u;
let TileInner : u32 = 64u;

// Workgroup-shared copies of the current A tile and B tile.
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;

[[stage(compute), workgroup_size(16, 16, 1)]]
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
        [[builtin(global_invocation_id)]] global_id : vec3<u32>) {
  // Top-left corner of this invocation's patch, in tile and global coordinates.
  let tileRow : u32 = local_id.y * RowPerThread;
  let tileCol : u32 = local_id.x * ColPerThread;
  let globalRow : u32 = global_id.y * RowPerThread;
  let globalCol : u32 = global_id.x * ColPerThread;

  // ceil(dimInner / TileInner). Written as add-then-divide so that
  // dimInner == 0 yields zero tiles instead of wrapping around in u32.
  let numTiles : u32 = (uniforms.dimInner + (TileInner - 1u)) / TileInner;

  var acc : array<f32, 16>;
  var ACached : f32;
  var BCached : array<f32, 4>;

  // Clear the accumulators.
  for (var index : u32 = 0u; index < (RowPerThread * ColPerThread); index = index + 1u) {
    acc[index] = 0.0;
  }

  // Share of each tile that this invocation is responsible for loading.
  let ColPerThreadA : u32 = TileInner / 16u;
  let tileColA : u32 = local_id.x * ColPerThreadA;
  let RowPerThreadB : u32 = TileInner / 16u;
  let tileRowB : u32 = local_id.y * RowPerThreadB;

  // Walk the inner dimension one tile at a time.
  for (var t : u32 = 0u; t < numTiles; t = t + 1u) {
    // Load this invocation's share of the A tile.
    for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
      for (var innerCol : u32 = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
        let inputRow : u32 = tileRow + innerRow;
        let inputCol : u32 = tileColA + innerCol;
        mm_Asub[inputRow][inputCol] =
            mm_readA(globalRow + innerRow, (t * TileInner) + inputCol);
      }
    }

    // Load this invocation's share of the B tile.
    for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
      for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
        let inputRow : u32 = tileRowB + innerRow;
        let inputCol : u32 = tileCol + innerCol;
        // The tile row is inputRow, matching the row passed to mm_readB.
        // Indexing by the loop counter innerCol would fill only rows 0..3,
        // while the multiply loop below reads every row k in 0..TileInner.
        mm_Bsub[inputRow][inputCol] =
            mm_readB((t * TileInner) + inputRow, globalCol + innerCol);
      }
    }

    // Both tiles must be fully written before any invocation reads them.
    workgroupBarrier();

    // Multiply-accumulate this tile pair into the 4x4 patch.
    for (var k : u32 = 0u; k < TileInner; k = k + 1u) {
      for (var inner : u32 = 0u; inner < ColPerThread; inner = inner + 1u) {
        BCached[inner] = mm_Bsub[k][tileCol + inner];
      }

      for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        ACached = mm_Asub[tileRow + innerRow][k];
        for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
          let index : u32 = (innerRow * ColPerThread) + innerCol;
          acc[index] = acc[index] + (ACached * BCached[innerCol]);
        }
      }
    }

    // Everyone must finish reading before the next iteration overwrites the tiles.
    workgroupBarrier();
  }

  // Write the accumulated patch; mm_write drops out-of-range elements.
  for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
    for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
      let index : u32 = (innerRow * ColPerThread) + innerCol;
      mm_write(globalRow + innerRow, globalCol + innerCol, acc[index]);
    }
  }
}