#include <metal_stdlib>

using namespace metal;

// Tiled matrix multiply: resultMatrix = firstMatrix * secondMatrix.
//
// As indexed by the accessors below, firstMatrix is dimAOuter x dimInner,
// secondMatrix is dimInner x dimBOuter and resultMatrix is
// dimAOuter x dimBOuter, all stored row-major in flat float arrays.

// Matrix dimensions supplied by the host.
struct Uniforms {
  /* 0x0000 */ uint dimAOuter;
  /* 0x0004 */ uint dimInner;
  /* 0x0008 */ uint dimBOuter;
};

// Flat float storage for one matrix.
// NOTE(review): declared with a single element but indexed past it by the
// accessors below; presumably a stand-in for a runtime-sized array -- confirm
// against the host-side buffer sizes.
struct Matrix {
  /* 0x0000 */ float numbers[1];
};

// Wrapper structs so fixed-size arrays can be assigned and passed by value.
struct tint_array_wrapper_1 {
  float arr[64];
};

struct tint_array_wrapper {
  tint_array_wrapper_1 arr[64];
};

struct tint_array_wrapper_2 {
  float arr[16];
};

struct tint_array_wrapper_3 {
  float arr[4];
};

// Each thread produces a RowPerThread x ColPerThread patch of the result.
constant uint RowPerThread = 4u;
constant uint ColPerThread = 4u;
// Tile extents; the threadgroup tiles are 64 x 64 floats.
constant uint TileAOuter = 64u;
constant uint TileBOuter = 64u;
constant uint TileInner = 64u;

// Returns firstMatrix[row][col], or 0 when (row, col) lies outside the
// dimAOuter x dimInner bounds.
float mm_readA(constant Uniforms& uniforms, const device Matrix& firstMatrix, uint row, uint col) {
  if ((row < uniforms.dimAOuter) && (col < uniforms.dimInner)) {
    return firstMatrix.numbers[(row * uniforms.dimInner) + col];
  }
  return 0.0f;
}

// Returns secondMatrix[row][col], or 0 when (row, col) lies outside the
// dimInner x dimBOuter bounds.
float mm_readB(constant Uniforms& uniforms, const device Matrix& secondMatrix, uint row, uint col) {
  if ((row < uniforms.dimInner) && (col < uniforms.dimBOuter)) {
    return secondMatrix.numbers[(row * uniforms.dimBOuter) + col];
  }
  return 0.0f;
}

// Stores value at resultMatrix[row][col]; out-of-bounds coordinates are
// silently ignored.
void mm_write(constant Uniforms& uniforms, device Matrix& resultMatrix, uint row, uint col, float value) {
  if ((row < uniforms.dimAOuter) && (col < uniforms.dimBOuter)) {
    uint const index = (col + (row * uniforms.dimBOuter));
    resultMatrix.numbers[index] = value;
  }
}

// Compute entry point. Each thread accumulates a RowPerThread x ColPerThread
// patch of the result while the threadgroup cooperatively stages one tile of
// each input in threadgroup memory per iteration.
//
// NOTE(review): the per-thread strides (4 rows / 4 columns into 64-wide tiles,
// and TileInner / 16) suggest a 16 x 16 threadgroup; the dispatch size is not
// visible in this file -- confirm on the host side.
kernel void tint_symbol(uint3 local_id [[thread_position_in_threadgroup]],
                        uint3 global_id [[thread_position_in_grid]],
                        uint local_invocation_index [[thread_index_in_threadgroup]],
                        constant Uniforms& uniforms [[buffer(3)]],
                        const device Matrix& firstMatrix [[buffer(0)]],
                        const device Matrix& secondMatrix [[buffer(1)]],
                        device Matrix& resultMatrix [[buffer(2)]]) {
  threadgroup tint_array_wrapper tileA;
  threadgroup tint_array_wrapper tileB;

  // One thread zero-fills both tiles; the barrier publishes that to the rest
  // of the threadgroup before anyone reads or writes them.
  if (local_invocation_index == 0u) {
    tint_array_wrapper const zeroA = {.arr={}};
    tileA = zeroA;
    tint_array_wrapper const zeroB = {.arr={}};
    tileB = zeroB;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  uint const tileRow = (local_id.y * RowPerThread);
  uint const tileCol = (local_id.x * ColPerThread);
  uint const globalRow = (global_id.y * RowPerThread);
  uint const globalCol = (global_id.x * ColPerThread);

  // ceil(dimInner / TileInner). Guard dimInner == 0 explicitly: the unguarded
  // "(dimInner - 1) / TileInner + 1" wraps around in unsigned arithmetic and
  // would request ~67 million tiles of pure zero accumulation.
  uint const numTiles = (uniforms.dimInner == 0u)
                            ? 0u
                            : (((uniforms.dimInner - 1u) / TileInner) + 1u);

  // Per-thread accumulators; "= {}" zero-initializes every element.
  tint_array_wrapper_2 acc = {};
  float ACached = 0.0f;
  tint_array_wrapper_3 BCached = {};

  // Each thread loads a RowPerThread x ColPerThreadA patch of the A tile and
  // a RowPerThreadB x ColPerThread patch of the B tile.
  uint const ColPerThreadA = (TileInner / 16u);
  uint const tileColA = (local_id.x * ColPerThreadA);
  uint const RowPerThreadB = (TileInner / 16u);
  uint const tileRowB = (local_id.y * RowPerThreadB);

  // numTiles depends only on uniforms, so every thread runs the same number
  // of iterations and reaches the barriers inside the loop together.
  for (uint t = 0u; t < numTiles; t = t + 1u) {
    // Stage this thread's share of the A tile.
    for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
        uint const inputRow = (tileRow + innerRow);
        uint const inputCol = (tileColA + innerCol);
        tileA.arr[inputRow].arr[inputCol] =
            mm_readA(uniforms, firstMatrix, (globalRow + innerRow), ((t * TileInner) + inputCol));
      }
    }

    // Stage this thread's share of the B tile.
    for (uint innerRow = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
        uint const inputRow = (tileRowB + innerRow);
        uint const inputCol = (tileCol + innerCol);
        // The tile row must be inputRow: the compute phase below reads
        // tileB.arr[k] for every k in [0, TileInner). Indexing the row by
        // innerCol would populate only rows 0..ColPerThread-1 and make
        // threads with different local_id.y race on the same elements.
        tileB.arr[inputRow].arr[inputCol] =
            mm_readB(uniforms, secondMatrix, ((t * TileInner) + inputRow), (globalCol + innerCol));
      }
    }

    // Both tiles must be fully written before any thread consumes them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint k = 0u; k < TileInner; k = k + 1u) {
      // Cache the ColPerThread B values of row k that this thread needs.
      for (uint inner = 0u; inner < ColPerThread; inner = inner + 1u) {
        BCached.arr[inner] = tileB.arr[k].arr[(tileCol + inner)];
      }
      for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        ACached = tileA.arr[(tileRow + innerRow)].arr[k];
        for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
          uint const index = ((innerRow * ColPerThread) + innerCol);
          acc.arr[index] = (acc.arr[index] + (ACached * BCached.arr[innerCol]));
        }
      }
    }

    // Do not let the next iteration overwrite tiles that are still being read.
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write this thread's patch of the result; mm_write drops out-of-range
  // coordinates.
  for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
    for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
      uint const index = ((innerRow * ColPerThread) + innerCol);
      mm_write(uniforms, resultMatrix, (globalRow + innerRow), (globalCol + innerCol), acc.arr[index]);
    }
  }
}