// Tiled matrix multiply: resultMatrix = firstMatrix * secondMatrix.
//
// firstMatrix  is dimAOuter x dimInner,
// secondMatrix is dimInner  x dimBOuter,
// resultMatrix is dimAOuter x dimBOuter,
// each stored row-major in a flat f32 array (see the index arithmetic in
// mm_readA / mm_readB / mm_write).
//
// NOTE(review): the storage access modes below (inputs read-only, output
// read_write) are assumed — confirm against the host-side bind group layout.

[[block]] struct Uniforms {
  dimAOuter : u32;
  dimInner : u32;
  dimBOuter : u32;
};

[[block]] struct Matrix {
  numbers : array<f32>;
};

[[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix;
[[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix;
[[group(0), binding(2)]] var<storage, read_write> resultMatrix : Matrix;
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;

// Reads A[row][col]; returns 0 outside the dimAOuter x dimInner bounds so that
// partial edge tiles are zero-padded.
fn mm_readA(row : u32, col : u32) -> f32 {
  if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
    let result : f32 = firstMatrix.numbers[row * uniforms.dimInner + col];
    return result;
  }
  return 0.;
}

// Reads B[row][col]; returns 0 outside the dimInner x dimBOuter bounds so that
// partial edge tiles are zero-padded.
fn mm_readB(row : u32, col : u32) -> f32 {
  if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
    let result : f32 = secondMatrix.numbers[row * uniforms.dimBOuter + col];
    return result;
  }
  return 0.;
}

// Writes C[row][col] = value; out-of-bounds writes (from edge tiles) are dropped.
fn mm_write(row : u32, col : u32, value : f32) {
  if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
    let index : u32 = col + row * uniforms.dimBOuter;
    resultMatrix.numbers[index] = value;
  }
}

// Each invocation computes a RowPerThread x ColPerThread block of the output.
let RowPerThread : u32 = 4u;
let ColPerThread : u32 = 4u;
// Workgroup tile extents; 16 invocations * 4 per invocation = 64 along each axis.
let TileAOuter : u32 = 64u;
let TileBOuter : u32 = 64u;
let TileInner : u32 = 64u;

// 64x64 tiles of A and B staged in workgroup memory, indexed [row][col].
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;

[[stage(compute), workgroup_size(16, 16, 1)]]
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
        [[builtin(global_invocation_id)]] global_id : vec3<u32>) {
  // Origin of this invocation's output block inside the workgroup tile
  // (tileRow/tileCol) and inside the full result matrix (globalRow/globalCol).
  let tileRow : u32 = local_id.y * RowPerThread;
  let tileCol : u32 = local_id.x * ColPerThread;
  let globalRow : u32 = global_id.y * RowPerThread;
  let globalCol : u32 = global_id.x * ColPerThread;

  // ceil(dimInner / TileInner). Written as (n + d - 1) / d so that
  // dimInner == 0 yields 0 tiles; the previous form (dimInner - 1u) / d + 1u
  // wrapped around to ~67M tiles for an empty inner dimension.
  let numTiles : u32 = (uniforms.dimInner + TileInner - 1u) / TileInner;

  // acc holds RowPerThread * ColPerThread (= 16) partial sums, row-major.
  var acc : array<f32, 16>;
  // BCached holds ColPerThread (= 4) values of the current row k of the B tile.
  var BCached : array<f32, 4>;

  // Without this initialization strange values show up in acc.
  // TODO: Remove it once the following bug is fixed.
  // https://bugs.chromium.org/p/tint/issues/detail?id=759
  for (var index : u32 = 0u; index < RowPerThread * ColPerThread; index = index + 1u) {
    acc[index] = 0.;
  }

  // The 64-wide inner dimension of each tile is shared among the 16
  // invocations along x (for A) / along y (for B): 4 elements each.
  let ColPerThreadA : u32 = TileInner / 16u;
  let tileColA : u32 = local_id.x * ColPerThreadA;
  let RowPerThreadB : u32 = TileInner / 16u;
  let tileRowB : u32 = local_id.y * RowPerThreadB;

  // Loop over shared dimension.
  for (var t : u32 = 0u; t < numTiles; t = t + 1u) {
    // Load one tile of A into local memory.
    for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
      for (var innerCol : u32 = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
        let inputRow : u32 = tileRow + innerRow;
        let inputCol : u32 = tileColA + innerCol;
        mm_Asub[inputRow][inputCol] = mm_readA(globalRow + innerRow, t * TileInner + inputCol);
      }
    }

    // Load one tile of B into local memory.
    for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
      for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
        let inputRow : u32 = tileRowB + innerRow;
        let inputCol : u32 = tileCol + innerCol;
        // Row index is inputRow (the tile-local row of B that was read), not
        // innerCol: indexing by innerCol left rows 4..63 of mm_Bsub unwritten
        // and made invocations race on rows 0..3.
        mm_Bsub[inputRow][inputCol] = mm_readB(t * TileInner + inputRow, globalCol + innerCol);
      }
    }

    // Both tiles must be fully populated before any invocation reads them.
    workgroupBarrier();

    // Compute acc values for a single thread.
    for (var k : u32 = 0u; k < TileInner; k = k + 1u) {
      for (var inner : u32 = 0u; inner < ColPerThread; inner = inner + 1u) {
        BCached[inner] = mm_Bsub[k][tileCol + inner];
      }

      for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        let ACached : f32 = mm_Asub[tileRow + innerRow][k];
        for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
          let index : u32 = innerRow * ColPerThread + innerCol;
          acc[index] = acc[index] + ACached * BCached[innerCol];
        }
      }
    }

    // Keep the next iteration's loads from overwriting tiles still being read.
    workgroupBarrier();
  }

  // Write this invocation's output block; mm_write drops out-of-bounds cells.
  for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
    for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
      let index : u32 = innerRow * ColPerThread + innerCol;
      mm_write(globalRow + innerRow, globalCol + innerCol, acc[index]);
    }
  }
}