121 lines
5.6 KiB
Plaintext
121 lines
5.6 KiB
Plaintext
#include <metal_stdlib>
|
|
|
|
using namespace metal;
|
|
// Matrix dimensions read from the constant buffer bound at buffer(0).
// Judging by the bounds checks in mm_readA / mm_readB / mm_write:
// A is dimAOuter x dimInner, B is dimInner x dimBOuter, and the output is
// dimAOuter x dimBOuter. Field order fixes the byte offsets noted below and
// must match the host-side layout.
struct Uniforms {
  uint dimAOuter;  // offset 0x0000
  uint dimInner;   // offset 0x0004
  uint dimBOuter;  // offset 0x0008
};
|
|
// Flat float storage for one matrix buffer. The one-element trailing array
// stands in for a runtime-sized array: callers index it as
// `row * columnCount + col`, well past element 0, up to the bound buffer's
// real length.
struct Matrix {
  float numbers[1];  // offset 0x0000
};
|
|
// Struct wrappers around fixed-size float arrays, as emitted by Tint.
// NOTE(review): presumably the wrapping exists so the arrays behave as whole
// values (copy / `= {}` initialisation) — confirm against Tint's docs.

// Four floats; used as the per-thread cache of one B-tile row (BCached).
struct tint_array_wrapper_3 {
  float arr[4];
};

// Sixteen floats; used as the per-thread accumulator block (acc).
struct tint_array_wrapper_2 {
  float arr[16];
};

// One 64-float row of a staging tile.
struct tint_array_wrapper_1 {
  float arr[64];
};

// A 64 x 64 staging tile, indexed as tile.arr[row].arr[col]; the kernel
// allocates two of these in threadgroup memory.
struct tint_array_wrapper {
  tint_array_wrapper_1 arr[64];
};
|
|
|
|
// Size of the output block each thread accumulates
// (RowPerThread * ColPerThread must equal the 16 slots of tint_array_wrapper_2).
constant uint RowPerThread = 4u;
constant uint ColPerThread = 4u;

// Edge lengths of the tiles staged in threadgroup memory. TileInner must agree
// with the 64-element extents of tint_array_wrapper / tint_array_wrapper_1.
// NOTE(review): TileAOuter and TileBOuter are not referenced in the code shown.
constant uint TileInner = 64u;
constant uint TileAOuter = 64u;
constant uint TileBOuter = 64u;
|
|
// Returns A[row][col], where A is stored row-major with dimInner columns.
// Coordinates outside dimAOuter x dimInner yield 0.0f, so callers may read
// past the matrix edge without touching memory out of range.
float mm_readA(uint row, uint col, const constant Uniforms* const tint_symbol_1, const device Matrix* const tint_symbol_2) {
  uint const rowCount = tint_symbol_1->dimAOuter;
  uint const colCount = tint_symbol_1->dimInner;
  if (row >= rowCount || col >= colCount) {
    return 0.0f;
  }
  return tint_symbol_2->numbers[row * colCount + col];
}
|
|
|
|
// Returns B[row][col], where B is stored row-major with dimBOuter columns.
// Coordinates outside dimInner x dimBOuter yield 0.0f, so callers may read
// past the matrix edge without touching memory out of range.
float mm_readB(uint row, uint col, const constant Uniforms* const tint_symbol_3, const device Matrix* const tint_symbol_4) {
  uint const rowCount = tint_symbol_3->dimInner;
  uint const colCount = tint_symbol_3->dimBOuter;
  if (row >= rowCount || col >= colCount) {
    return 0.0f;
  }
  return tint_symbol_4->numbers[row * colCount + col];
}
|
|
|
|
// Stores `value` at output[row][col] (row-major, dimBOuter columns).
// Writes outside dimAOuter x dimBOuter are silently dropped.
void mm_write(uint row, uint col, float value, const constant Uniforms* const tint_symbol_5, device Matrix* const tint_symbol_6) {
  uint const rowCount = tint_symbol_5->dimAOuter;
  uint const colCount = tint_symbol_5->dimBOuter;
  bool const outside = (row >= rowCount) || (col >= colCount);
  if (outside) {
    return;
  }
  tint_symbol_6->numbers[row * colCount + col] = value;
}
|
|
|
|
// Tiled matrix multiply: writes A * B to the output matrix, where A is
// dimAOuter x dimInner and B is dimInner x dimBOuter (both row-major, as read
// by mm_readA / mm_readB).
//
// Each thread owns a RowPerThread x ColPerThread block of the output. For each
// TileInner-wide slice of the inner dimension, the threadgroup cooperatively
// stages a 64 x 64 block of A (tint_symbol_7) and of B (tint_symbol_8) in
// threadgroup memory. It then synchronises and accumulates the partial
// products into per-thread registers. Out-of-range reads return 0.0f, so a
// partial last slice contributes nothing extra.
//
// NOTE(review): the index math assumes a 16 x 16 (256-invocation) threadgroup:
// the zero-fill stride is 256, and local_id.{x,y} * 4 must stay below 64.
// The threadgroup size is chosen by the host and is not visible here — confirm.
//
// Parameters:
//   local_id, global_id, local_invocation_index - thread coordinates.
//   tint_symbol_7  - threadgroup tile holding the current slice of A.
//   tint_symbol_8  - threadgroup tile holding the current slice of B.
//   tint_symbol_9  - matrix dimensions.
//   tint_symbol_10 - input matrix A (read through mm_readA).
//   tint_symbol_11 - input matrix B (read through mm_readB).
//   tint_symbol_12 - output matrix (written through mm_write).
void tint_symbol_inner(uint3 local_id, uint3 global_id, uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol_7, threadgroup tint_array_wrapper* const tint_symbol_8, const constant Uniforms* const tint_symbol_9, const device Matrix* const tint_symbol_10, const device Matrix* const tint_symbol_11, device Matrix* const tint_symbol_12) {
  // Cooperatively zero both 64 x 64 tiles (4096 elements each).
  for (uint idx = local_invocation_index; idx < 4096u; idx = idx + 256u) {
    uint const i = idx / 64u;
    uint const i_1 = idx % 64u;
    (*(tint_symbol_7)).arr[i].arr[i_1] = float();
    (*(tint_symbol_8)).arr[i].arr[i_1] = float();
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Top-left corner of this thread's output block, inside the tile and
  // inside the full output matrix.
  uint const tileRow = local_id[1] * RowPerThread;
  uint const tileCol = local_id[0] * ColPerThread;
  uint const globalRow = global_id[1] * RowPerThread;
  uint const globalCol = global_id[0] * ColPerThread;

  // ceil(dimInner / TileInner). The (d - 1) / T + 1 form cannot overflow near
  // UINT_MAX, but it wraps for d == 0 (yielding 2^26 useless tile passes), so
  // the empty inner dimension is handled explicitly.
  uint const dimInner = (*(tint_symbol_9)).dimInner;
  uint const numTiles = (dimInner == 0u) ? 0u : (((dimInner - 1u) / TileInner) + 1u);

  // Per-thread accumulators; `= {}` zero-initialises every element.
  tint_array_wrapper_2 acc = {};
  float ACached = 0.0f;
  tint_array_wrapper_3 BCached = {};

  // Share of each tile axis loaded per thread, assuming 16 threads per axis.
  uint const ColPerThreadA = TileInner / 16u;
  uint const tileColA = local_id[0] * ColPerThreadA;
  uint const RowPerThreadB = TileInner / 16u;
  uint const tileRowB = local_id[1] * RowPerThreadB;

  for (uint t = 0u; t < numTiles; t = t + 1u) {
    uint const tileBase = t * TileInner;

    // Stage A[globalRow + r][tileBase + c] into the A tile.
    for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
        uint const inputRow = tileRow + innerRow;
        uint const inputCol = tileColA + innerCol;
        (*(tint_symbol_7)).arr[inputRow].arr[inputCol] = mm_readA((globalRow + innerRow), (tileBase + inputCol), tint_symbol_9, tint_symbol_10);
      }
    }

    // Stage B[tileBase + r][globalCol + c] into the B tile. The tile row must
    // be inputRow: it is the `k` the compute loop below reads back. Indexing
    // by innerCol would only ever fill rows 0..3, racing between threads.
    for (uint innerRow = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
        uint const inputRow = tileRowB + innerRow;
        uint const inputCol = tileCol + innerCol;
        (*(tint_symbol_8)).arr[inputRow].arr[inputCol] = mm_readB((tileBase + inputRow), (globalCol + innerCol), tint_symbol_9, tint_symbol_11);
      }
    }

    // Both tiles must be fully written before any thread reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint k = 0u; k < TileInner; k = k + 1u) {
      for (uint inner = 0u; inner < ColPerThread; inner = inner + 1u) {
        BCached.arr[inner] = (*(tint_symbol_8)).arr[k].arr[(tileCol + inner)];
      }
      for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        ACached = (*(tint_symbol_7)).arr[(tileRow + innerRow)].arr[k];
        for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
          uint const index = (innerRow * ColPerThread) + innerCol;
          acc.arr[index] = acc.arr[index] + (ACached * BCached.arr[innerCol]);
        }
      }
    }

    // Keep the next pass's loads from overwriting tiles still being read.
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write this thread's block; mm_write drops out-of-range coordinates.
  for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
    for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
      uint const index = (innerRow * ColPerThread) + innerCol;
      mm_write((globalRow + innerRow), (globalCol + innerCol), acc.arr[index], tint_symbol_9, tint_symbol_12);
    }
  }
}
|
|
|
|
// Compute entry point.
// Bindings: buffer(0) dimensions, buffer(1) output matrix, buffer(2) the
// matrix read through mm_readA, buffer(3) the matrix read through mm_readB.
// Allocates the two threadgroup staging tiles and forwards everything to
// tint_symbol_inner.
kernel void tint_symbol(const constant Uniforms* tint_symbol_15 [[buffer(0)]],
                        const device Matrix* tint_symbol_16 [[buffer(2)]],
                        const device Matrix* tint_symbol_17 [[buffer(3)]],
                        device Matrix* tint_symbol_18 [[buffer(1)]],
                        uint3 local_id [[thread_position_in_threadgroup]],
                        uint3 global_id [[thread_position_in_grid]],
                        uint local_invocation_index [[thread_index_in_threadgroup]]) {
  // Staging tiles shared by every thread of the threadgroup.
  threadgroup tint_array_wrapper tileA;
  threadgroup tint_array_wrapper tileB;
  tint_symbol_inner(local_id, global_id, local_invocation_index,
                    &tileA, &tileB,
                    tint_symbol_15, tint_symbol_16, tint_symbol_17, tint_symbol_18);
}
|
|
|