// Compute shader: multiplies matrix A by matrix B and writes the product to
// the output buffer, staging pieces of A and B in threadgroup memory.
//
// Entry point: kernel `tint_symbol_1` (last definition in this file).
// Call chain: tint_symbol_1 -> tint_symbol_1_inner -> main_1 ->
// mm_matMul_i1_i1_i1_ -> mm_readA_i1_i1_ / mm_readB_i1_i1_ / mm_write_i1_i1_f1_.
//
// NOTE(review): presumably machine-generated (tint_* / x_NNN names) — confirm
// before hand-editing.
// NOTE(review): the `#include` directive below has no header name, and every
// `as_type(...)` call in this file has no destination-type template argument.
// It looks like angle-bracket text was stripped somewhere upstream — confirm
// against the original generator output before compiling.
// Comments below describe the underlying +, -, * arithmetic and assume the
// as_type(...) wrappers only reinterpret bits — TODO confirm once the template
// arguments are restored.
#include
using namespace metal;

// Uniform block bound at buffer(0). The offset comments and explicit
// tint_pad* members spell out the byte layout.
struct Uniforms {
  /* 0x0000 */ float tint_symbol;
  /* 0x0004 */ int8_t tint_pad[12];
  /* 0x0010 */ packed_int3 aShape;
  /* 0x001c */ int8_t tint_pad_1[4];
  /* 0x0020 */ packed_int3 bShape;
  /* 0x002c */ int8_t tint_pad_2[4];
  /* 0x0030 */ packed_int3 outShape;
  /* 0x003c */ int8_t tint_pad_3[4];
  /* 0x0040 */ int2 outShapeStrides;
};

// Output buffer (bound at buffer(1)). Declared with one element but written at
// a computed index in setOutput_i1_f1_ — presumably a runtime-sized buffer;
// confirm against the host-side binding.
struct ssbOut {
  /* 0x0000 */ float result[1];
};

// Input buffer A (bound at buffer(2)); read at a computed index in
// mm_readA_i1_i1_. Same declared-size caveat as ssbOut.
struct ssbA {
  /* 0x0000 */ float A[1];
};

// Input buffer B (bound at buffer(3)); read at a computed index in
// mm_readB_i1_i1_. Same declared-size caveat as ssbOut.
struct ssbB {
  /* 0x0000 */ float B[1];
};

// One row of 64 floats.
struct tint_array_wrapper_1 {
  float arr[64];
};

// 64 x 64 floats; used as the threadgroup staging array for values read from A.
struct tint_array_wrapper {
  tint_array_wrapper_1 arr[64];
};

// One row of 1 float.
struct tint_array_wrapper_3 {
  float arr[1];
};

// 64 x 1 floats; used as the threadgroup staging array for values read from B.
struct tint_array_wrapper_2 {
  tint_array_wrapper_3 arr[64];
};

// 1 x 1 floats; used as the per-thread accumulator in mm_matMul_i1_i1_i1_.
struct tint_array_wrapper_4 {
  tint_array_wrapper_3 arr[1];
};

// Returns true when every component of *coord is >= 0 and every component of
// *coord is < the matching component of *shape. The upper-bound comparison is
// only evaluated when the lower-bound comparison passed.
bool coordsInBounds_vi2_vi2_(thread int2* const coord,
                             thread int2* const shape) {
  bool x_87 = false;
  bool x_88_phi = false;
  int2 const x_76 = *(coord);
  bool const x_81 = all((x_76 >= int2(0, 0)));
  x_88_phi = x_81;
  if (x_81) {
    int2 const x_84 = *(coord);
    int2 const x_85 = *(shape);
    x_87 = all((x_84 < x_85));
    x_88_phi = x_87;
  }
  bool const x_88 = x_88_phi;
  return x_88;
}

// Bounds-checked read from buffer A.
//   row, col       : coordinate to read.
//   tint_symbol_2  : upper bound for row (the caller passes aShape.y).
//   tint_symbol_3  : upper bound for col and the row stride (aShape.z).
//   tint_symbol_4  : batch index (the caller passes the global invocation z).
// Returns A[batch * (aShape.y * aShape.z) + row * (*tint_symbol_3) + col] when
// (row, col) passes coordsInBounds_vi2_vi2_, otherwise 0.0f.
float mm_readA_i1_i1_(constant Uniforms& x_48,
                      const device ssbA& x_165,
                      thread int* const row,
                      thread int* const col,
                      thread int* const tint_symbol_2,
                      thread int* const tint_symbol_3,
                      thread int* const tint_symbol_4) {
  int batchASize = 0;
  int2 param_10 = 0;
  int2 param_11 = 0;
  float x_430 = 0.0f;
  int const x_417 = x_48.aShape.y;
  int const x_419 = x_48.aShape.z;
  // Element count of one batch slice of A.
  batchASize = as_type((as_type(x_417) * as_type(x_419)));
  int const x_421 = *(row);
  int const x_422 = *(col);
  int const x_424 = *(tint_symbol_2);
  int const x_425 = *(tint_symbol_3);
  param_10 = int2(x_421, x_422);
  param_11 = int2(x_424, x_425);
  bool const x_429 = coordsInBounds_vi2_vi2_(&(param_10), &(param_11));
  if (x_429) {
    int const x_438 = *(tint_symbol_4);
    int const x_439 = batchASize;
    int const x_441 = *(row);
    int const x_442 = *(tint_symbol_3);
    int const x_445 = *(col);
    // Index = x_438 * x_439 + x_441 * x_442 + x_445.
    float const x_448 = x_165.A[as_type((as_type(as_type((as_type(as_type((as_type(x_438) * as_type(x_439)))) + as_type(as_type((as_type(x_441) * as_type(x_442))))))) + as_type(x_445)))];
    x_430 = x_448;
  } else {
    // Out-of-range coordinates read as zero.
    x_430 = 0.0f;
  }
  float const x_450 = x_430;
  return x_450;
}

// Bounds-checked read from buffer B; mirrors mm_readA_i1_i1_.
//   row_1, col_1   : coordinate to read.
//   tint_symbol_5  : upper bound for row_1 (the caller passes aShape.z).
//   tint_symbol_6  : upper bound for col_1 and the row stride (bShape.z).
//   tint_symbol_7  : batch index.
// Returns B[batch * (bShape.y * bShape.z) + row_1 * (*tint_symbol_6) + col_1]
// when in bounds, otherwise 0.0f.
float mm_readB_i1_i1_(constant Uniforms& x_48,
                      const device ssbB& x_185,
                      thread int* const row_1,
                      thread int* const col_1,
                      thread int* const tint_symbol_5,
                      thread int* const tint_symbol_6,
                      thread int* const tint_symbol_7) {
  int batchBSize = 0;
  int2 param_12 = 0;
  int2 param_13 = 0;
  float x_468 = 0.0f;
  int const x_455 = x_48.bShape.y;
  int const x_457 = x_48.bShape.z;
  // Element count of one batch slice of B.
  batchBSize = as_type((as_type(x_455) * as_type(x_457)));
  int const x_459 = *(row_1);
  int const x_460 = *(col_1);
  int const x_462 = *(tint_symbol_5);
  int const x_463 = *(tint_symbol_6);
  param_12 = int2(x_459, x_460);
  param_13 = int2(x_462, x_463);
  bool const x_467 = coordsInBounds_vi2_vi2_(&(param_12), &(param_13));
  if (x_467) {
    int const x_475 = *(tint_symbol_7);
    int const x_476 = batchBSize;
    int const x_478 = *(row_1);
    int const x_479 = *(tint_symbol_6);
    int const x_482 = *(col_1);
    // Index = x_475 * x_476 + x_478 * x_479 + x_482.
    float const x_485 = x_185.B[as_type((as_type(as_type((as_type(as_type((as_type(x_475) * as_type(x_476)))) + as_type(as_type((as_type(x_478) * as_type(x_479))))))) + as_type(x_482)))];
    x_468 = x_485;
  } else {
    // Out-of-range coordinates read as zero.
    x_468 = 0.0f;
  }
  float const x_487 = x_468;
  return x_487;
}

// Converts a 3-component coordinate into a flat output index: the dot product
// of *coords with (outShapeStrides.x, outShapeStrides.y, 1).
// NOTE(review): the dot product is evaluated on float3 values and converted
// back to int, so results are limited by float precision — confirm this is
// acceptable for the largest output sizes in use.
int getOutputFlatIndex_vi3_(constant Uniforms& x_48,
                            thread int3* const coords) {
  int3 const x_99 = *(coords);
  int const x_105 = x_48.outShapeStrides.x;
  int const x_107 = x_48.outShapeStrides.y;
  return int(dot(float3(x_99), float3(int3(x_105, x_107, 1))));
}

// Stores *value into result[*flatIndex]. No bounds check is performed here.
void setOutput_i1_f1_(device ssbOut& x_54,
                      thread int* const flatIndex,
                      thread float* const value) {
  int const x_95 = *(flatIndex);
  float const x_96 = *(value);
  x_54.result[x_95] = x_96;
  return;
}

// Stores *value_1 at coordinate (*d0, *d1, *d2): flattens the coordinate with
// getOutputFlatIndex_vi3_ and forwards to setOutput_i1_f1_.
void setOutput_i1_i1_i1_f1_(constant Uniforms& x_48,
                            device ssbOut& x_54,
                            thread int* const d0,
                            thread int* const d1,
                            thread int* const d2,
                            thread float* const value_1) {
  int flatIndex_1 = 0;
  int3 param = 0;
  int param_1 = 0;
  float param_2 = 0.0f;
  int const x_115 = *(d0);
  int const x_116 = *(d1);
  int const x_117 = *(d2);
  param = int3(x_115, x_116, x_117);
  int const x_120 = getOutputFlatIndex_vi3_(x_48, &(param));
  flatIndex_1 = x_120;
  int const x_122 = flatIndex_1;
  param_1 = x_122;
  float const x_124 = *(value_1);
  param_2 = x_124;
  setOutput_i1_f1_(x_54, &(param_1), &(param_2));
  return;
}

// Writes *value_2 to the output at coordinate (*tint_symbol_8, *row_2, *col_2),
// where tint_symbol_8 is the batch index supplied by the caller.
void mm_write_i1_i1_f1_(constant Uniforms& x_48,
                        device ssbOut& x_54,
                        thread int* const row_2,
                        thread int* const col_2,
                        thread float* const value_2,
                        thread int* const tint_symbol_8) {
  int3 outCoord = 0;
  int param_14 = 0;
  int param_15 = 0;
  int param_16 = 0;
  float param_17 = 0.0f;
  int const x_491 = *(tint_symbol_8);
  int const x_492 = *(row_2);
  int const x_493 = *(col_2);
  // NOTE(review): outCoord is assigned here but never read in this function.
  outCoord = int3(x_491, x_492, x_493);
  int const x_496 = *(tint_symbol_8);
  param_14 = x_496;
  int const x_498 = *(row_2);
  param_15 = x_498;
  int const x_500 = *(col_2);
  param_16 = x_500;
  float const x_502 = *(value_2);
  param_17 = x_502;
  setOutput_i1_i1_i1_f1_(x_48, x_54, &(param_14), &(param_15), &(param_16), &(param_17));
  return;
}

// Matrix-multiply body executed by each thread.
//   dimAOuter, dimInner, dimBOuter : problem dimensions (main_1 passes
//                                    aShape.y, aShape.z, bShape.z).
//   tint_symbol_9   : local invocation id (as passed by main_1).
//   tint_symbol_10  : global invocation id (as passed by main_1).
//   tint_symbol_11  : forwarded to mm_readA as the row bound (aShape.y).
//   tint_symbol_12  : forwarded to mm_readA as col bound/stride and to
//                     mm_readB as the row bound (aShape.z).
//   tint_symbol_13  : batch index, forwarded to the read/write helpers.
//   tint_symbol_14  : 64x64 threadgroup array holding staged A values.
//   tint_symbol_15  : forwarded to mm_readB as col bound/stride (bShape.z).
//   tint_symbol_16  : 64x1 threadgroup array holding staged B values.
// Steps: zero the accumulator; for each of numTiles steps, stage values from
// A and B into threadgroup memory, barrier, accumulate 64 products, barrier;
// finally write the accumulator to the output if the thread's coordinate is
// inside (dimAOuter, dimBOuter).
void mm_matMul_i1_i1_i1_(constant Uniforms& x_48,
                         const device ssbA& x_165,
                         const device ssbB& x_185,
                         device ssbOut& x_54,
                         thread int* const dimAOuter,
                         thread int* const dimInner,
                         thread int* const dimBOuter,
                         thread uint3* const tint_symbol_9,
                         thread uint3* const tint_symbol_10,
                         thread int* const tint_symbol_11,
                         thread int* const tint_symbol_12,
                         thread int* const tint_symbol_13,
                         threadgroup tint_array_wrapper* const tint_symbol_14,
                         thread int* const tint_symbol_15,
                         threadgroup tint_array_wrapper_2* const tint_symbol_16) {
  int tileRow = 0;
  int tileCol = 0;
  int globalRow = 0;
  int globalCol = 0;
  int numTiles = 0;
  int innerRow = 0;
  int innerCol = 0;
  tint_array_wrapper_4 acc = {};
  int tileColA = 0;
  int tileRowB = 0;
  int t = 0;
  int innerRow_1 = 0;
  int innerCol_1 = 0;
  int inputRow = 0;
  int inputCol = 0;
  int param_3 = 0;
  int param_4 = 0;
  int innerRow_2 = 0;
  int innerCol_2 = 0;
  int inputRow_1 = 0;
  int inputCol_1 = 0;
  int param_5 = 0;
  int param_6 = 0;
  int k = 0;
  int inner = 0;
  tint_array_wrapper_3 BCached = {};
  int innerRow_3 = 0;
  float ACached = 0.0f;
  int innerCol_3 = 0;
  int innerRow_4 = 0;
  int innerCol_4 = 0;
  int param_7 = 0;
  int param_8 = 0;
  float param_9 = 0.0f;

  // Per-thread origin inside the threadgroup (from tint_symbol_9) and inside
  // the whole grid (from tint_symbol_10); each is multiplied by 1.
  uint const x_132 = (*(tint_symbol_9)).y;
  tileRow = as_type((as_type(as_type(x_132)) * as_type(1)));
  uint const x_137 = (*(tint_symbol_9)).x;
  tileCol = as_type((as_type(as_type(x_137)) * as_type(1)));
  uint const x_143 = (*(tint_symbol_10)).y;
  globalRow = as_type((as_type(as_type(x_143)) * as_type(1)));
  uint const x_148 = (*(tint_symbol_10)).x;
  globalCol = as_type((as_type(as_type(x_148)) * as_type(1)));

  // numTiles = ((*dimInner - 1) / 64) + 1.
  int const x_152 = *(dimInner);
  numTiles = as_type((as_type((as_type((as_type(x_152) - as_type(1))) / 64)) + as_type(1)));

  // Zero the 1x1 accumulator.
  innerRow = 0;
  while (true) {
    int const x_163 = innerRow;
    if ((x_163 < 1)) {
    } else {
      break;
    }
    innerCol = 0;
    while (true) {
      int const x_171 = innerCol;
      if ((x_171 < 1)) {
      } else {
        break;
      }
      int const x_177 = innerRow;
      int const x_178 = innerCol;
      acc.arr[x_177].arr[x_178] = 0.0f;
      {
        // Loop increment.
        int const x_181 = innerCol;
        innerCol = as_type((as_type(x_181) + as_type(1)));
      }
    }
    {
      // Loop increment.
      int const x_183 = innerRow;
      innerRow = as_type((as_type(x_183) + as_type(1)));
    }
  }

  // Starting column in the A staging array (local x * 64) and starting row in
  // the B staging array (local y * 1).
  uint const x_187 = (*(tint_symbol_9)).x;
  tileColA = as_type((as_type(as_type(x_187)) * as_type(64)));
  uint const x_192 = (*(tint_symbol_9)).y;
  tileRowB = as_type((as_type(as_type(x_192)) * as_type(1)));

  // Main loop over the numTiles steps along the inner dimension.
  t = 0;
  while (true) {
    int const x_201 = t;
    int const x_202 = numTiles;
    if ((x_201 < x_202)) {
    } else {
      break;
    }

    // Stage A: this thread fills 1 row x 64 columns of tint_symbol_14.
    innerRow_1 = 0;
    while (true) {
      int const x_210 = innerRow_1;
      if ((x_210 < 1)) {
      } else {
        break;
      }
      innerCol_1 = 0;
      while (true) {
        int const x_218 = innerCol_1;
        if ((x_218 < 64)) {
        } else {
          break;
        }
        int const x_221 = tileRow;
        int const x_222 = innerRow_1;
        inputRow = as_type((as_type(x_221) + as_type(x_222)));
        int const x_225 = tileColA;
        int const x_226 = innerCol_1;
        inputCol = as_type((as_type(x_225) + as_type(x_226)));
        int const x_233 = inputRow;
        int const x_234 = inputCol;
        int const x_235 = globalRow;
        int const x_236 = innerRow_1;
        int const x_238 = t;
        int const x_240 = inputCol;
        // Source coordinate in A: (globalRow + innerRow_1, t * 64 + inputCol).
        param_3 = as_type((as_type(x_235) + as_type(x_236)));
        param_4 = as_type((as_type(as_type((as_type(x_238) * as_type(64)))) + as_type(x_240)));
        float const x_244 = mm_readA_i1_i1_(x_48, x_165, &(param_3), &(param_4), tint_symbol_11, tint_symbol_12, tint_symbol_13);
        (*(tint_symbol_14)).arr[x_233].arr[x_234] = x_244;
        {
          // Loop increment.
          int const x_247 = innerCol_1;
          innerCol_1 = as_type((as_type(x_247) + as_type(1)));
        }
      }
      {
        // Loop increment.
        int const x_249 = innerRow_1;
        innerRow_1 = as_type((as_type(x_249) + as_type(1)));
      }
    }

    // Stage B: this thread fills a single element of tint_symbol_16.
    innerRow_2 = 0;
    while (true) {
      int const x_257 = innerRow_2;
      if ((x_257 < 1)) {
      } else {
        break;
      }
      innerCol_2 = 0;
      while (true) {
        int const x_265 = innerCol_2;
        if ((x_265 < 1)) {
        } else {
          break;
        }
        int const x_268 = tileRowB;
        int const x_269 = innerRow_2;
        inputRow_1 = as_type((as_type(x_268) + as_type(x_269)));
        int const x_272 = tileCol;
        int const x_273 = innerCol_2;
        inputCol_1 = as_type((as_type(x_272) + as_type(x_273)));
        int const x_278 = inputRow_1;
        int const x_279 = inputCol_1;
        int const x_280 = t;
        int const x_282 = inputRow_1;
        int const x_284 = globalCol;
        int const x_285 = innerCol_2;
        // Source coordinate in B: (t * 64 + inputRow_1, globalCol + innerCol_2).
        param_5 = as_type((as_type(as_type((as_type(x_280) * as_type(64)))) + as_type(x_282)));
        param_6 = as_type((as_type(x_284) + as_type(x_285)));
        float const x_289 = mm_readB_i1_i1_(x_48, x_185, &(param_5), &(param_6), tint_symbol_12, tint_symbol_15, tint_symbol_13);
        (*(tint_symbol_16)).arr[x_278].arr[x_279] = x_289;
        {
          // Loop increment.
          int const x_291 = innerCol_2;
          innerCol_2 = as_type((as_type(x_291) + as_type(1)));
        }
      }
      {
        // Loop increment.
        int const x_293 = innerRow_2;
        innerRow_2 = as_type((as_type(x_293) + as_type(1)));
      }
    }

    // Barrier between the staging writes above and the staged reads below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Accumulate 64 products from the staged data.
    k = 0;
    while (true) {
      int const x_302 = k;
      if ((x_302 < 64)) {
      } else {
        break;
      }
      // BCached[inner] = staged B at [k][tileCol + inner].
      inner = 0;
      while (true) {
        int const x_310 = inner;
        if ((x_310 < 1)) {
        } else {
          break;
        }
        int const x_314 = inner;
        int const x_315 = k;
        int const x_316 = tileCol;
        int const x_317 = inner;
        float const x_320 = (*(tint_symbol_16)).arr[x_315].arr[as_type((as_type(x_316) + as_type(x_317)))];
        BCached.arr[x_314] = x_320;
        {
          // Loop increment.
          int const x_322 = inner;
          inner = as_type((as_type(x_322) + as_type(1)));
        }
      }
      innerRow_3 = 0;
      while (true) {
        int const x_330 = innerRow_3;
        if ((x_330 < 1)) {
        } else {
          break;
        }
        // ACached = staged A at [tileRow + innerRow_3][k].
        int const x_333 = tileRow;
        int const x_334 = innerRow_3;
        int const x_336 = k;
        float const x_338 = (*(tint_symbol_14)).arr[as_type((as_type(x_333) + as_type(x_334)))].arr[x_336];
        ACached = x_338;
        innerCol_3 = 0;
        while (true) {
          int const x_345 = innerCol_3;
          if ((x_345 < 1)) {
          } else {
            break;
          }
          int const x_347 = innerRow_3;
          int const x_348 = innerCol_3;
          float const x_349 = ACached;
          int const x_350 = innerCol_3;
          float const x_352 = BCached.arr[x_350];
          float const x_355 = acc.arr[x_347].arr[x_348];
          // acc += ACached * BCached.
          acc.arr[x_347].arr[x_348] = (x_355 + (x_349 * x_352));
          {
            // Loop increment.
            int const x_358 = innerCol_3;
            innerCol_3 = as_type((as_type(x_358) + as_type(1)));
          }
        }
        {
          // Loop increment.
          int const x_360 = innerRow_3;
          innerRow_3 = as_type((as_type(x_360) + as_type(1)));
        }
      }
      {
        // Loop increment.
        int const x_362 = k;
        k = as_type((as_type(x_362) + as_type(1)));
      }
    }

    // Barrier before the next iteration overwrites the staging arrays.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    {
      // Loop increment.
      int const x_364 = t;
      t = as_type((as_type(x_364) + as_type(1)));
    }
  }

  // Write-out: store the accumulator when
  // globalCol + innerCol_4 < *dimBOuter and globalRow + innerRow_4 < *dimAOuter.
  innerRow_4 = 0;
  while (true) {
    int const x_372 = innerRow_4;
    if ((x_372 < 1)) {
    } else {
      break;
    }
    innerCol_4 = 0;
    while (true) {
      bool x_393 = false;
      bool x_394_phi = false;
      int const x_380 = innerCol_4;
      if ((x_380 < 1)) {
      } else {
        break;
      }
      int const x_382 = globalCol;
      int const x_383 = innerCol_4;
      int const x_385 = *(dimBOuter);
      bool const x_386 = (as_type((as_type(x_382) + as_type(x_383))) < x_385);
      x_394_phi = x_386;
      if (x_386) {
        // The row check is only evaluated when the column check passed.
        int const x_389 = globalRow;
        int const x_390 = innerRow_4;
        int const x_392 = *(dimAOuter);
        x_393 = (as_type((as_type(x_389) + as_type(x_390))) < x_392);
        x_394_phi = x_393;
      }
      bool const x_394 = x_394_phi;
      if (x_394) {
        int const x_397 = globalRow;
        int const x_398 = innerRow_4;
        int const x_400 = globalCol;
        int const x_401 = innerCol_4;
        int const x_403 = innerRow_4;
        int const x_404 = innerCol_4;
        param_7 = as_type((as_type(x_397) + as_type(x_398)));
        param_8 = as_type((as_type(x_400) + as_type(x_401)));
        float const x_409 = acc.arr[x_403].arr[x_404];
        param_9 = x_409;
        mm_write_i1_i1_f1_(x_48, x_54, &(param_7), &(param_8), &(param_9), tint_symbol_13);
      }
      {
        // Loop increment.
        int const x_411 = innerCol_4;
        innerCol_4 = as_type((as_type(x_411) + as_type(1)));
      }
    }
    {
      // Loop increment.
      int const x_413 = innerRow_4;
      innerRow_4 = as_type((as_type(x_413) + as_type(1)));
    }
  }
  return;
}

// Shader body: publishes the problem dimensions and batch index through the
// thread-private pointers, then runs mm_matMul_i1_i1_i1_.
//   tint_symbol_17 <- aShape.y, tint_symbol_18 <- aShape.z,
//   tint_symbol_19 <- bShape.z, tint_symbol_21 <- z of *tint_symbol_20.
//   tint_symbol_20 / tint_symbol_22 : global / local invocation ids, as set up
//                                     by tint_symbol_1_inner.
//   tint_symbol_23 / tint_symbol_24 : threadgroup staging arrays for A / B.
void main_1(constant Uniforms& x_48,
            const device ssbA& x_165,
            const device ssbB& x_185,
            device ssbOut& x_54,
            thread int* const tint_symbol_17,
            thread int* const tint_symbol_18,
            thread int* const tint_symbol_19,
            thread uint3* const tint_symbol_20,
            thread int* const tint_symbol_21,
            thread uint3* const tint_symbol_22,
            threadgroup tint_array_wrapper* const tint_symbol_23,
            threadgroup tint_array_wrapper_2* const tint_symbol_24) {
  int param_18 = 0;
  int param_19 = 0;
  int param_20 = 0;
  int const x_67 = x_48.aShape.y;
  *(tint_symbol_17) = x_67;
  int const x_71 = x_48.aShape.z;
  *(tint_symbol_18) = x_71;
  int const x_75 = x_48.bShape.z;
  *(tint_symbol_19) = x_75;
  uint const x_505 = (*(tint_symbol_20)).z;
  *(tint_symbol_21) = as_type(x_505);
  // Local copies are passed as dimAOuter / dimInner / dimBOuter.
  int const x_508 = *(tint_symbol_17);
  param_18 = x_508;
  int const x_510 = *(tint_symbol_18);
  param_19 = x_510;
  int const x_512 = *(tint_symbol_19);
  param_20 = x_512;
  mm_matMul_i1_i1_i1_(x_48, x_165, x_185, x_54, &(param_18), &(param_19), &(param_20), tint_symbol_22, tint_symbol_20, tint_symbol_17, tint_symbol_18, tint_symbol_21, tint_symbol_23, tint_symbol_19, tint_symbol_24);
  return;
}

// Kernel prologue: zero-fills both threadgroup staging arrays, synchronises,
// records the invocation ids in thread-private storage, then calls main_1.
//   tint_symbol_25 : 64x1 staging array (B); tint_symbol_26 : 64x64 (A).
//   tint_symbol_27 / tint_symbol_28 : receive the local / global invocation id.
//   tint_symbol_29..32 : thread-private scratch ints forwarded to main_1.
void tint_symbol_1_inner(constant Uniforms& x_48,
                         const device ssbA& x_165,
                         const device ssbB& x_185,
                         device ssbOut& x_54,
                         uint3 gl_LocalInvocationID_param,
                         uint3 gl_GlobalInvocationID_param,
                         uint local_invocation_index,
                         threadgroup tint_array_wrapper_2* const tint_symbol_25,
                         threadgroup tint_array_wrapper* const tint_symbol_26,
                         thread uint3* const tint_symbol_27,
                         thread uint3* const tint_symbol_28,
                         thread int* const tint_symbol_29,
                         thread int* const tint_symbol_30,
                         thread int* const tint_symbol_31,
                         thread int* const tint_symbol_32) {
  {
    // Each thread clears one element of the 64x1 array.
    // NOTE(review): arr[i_1] has 64 entries and i_1 is not range-checked, so
    // this relies on local_invocation_index < 64 — the threadgroup size is not
    // visible in this file; confirm against the dispatch configuration.
    uint const i_1 = local_invocation_index;
    uint const i_2 = (local_invocation_index % 1u);
    (*(tint_symbol_25)).arr[i_1].arr[i_2] = float();
  }
  // Clear the 64x64 array: each thread starts at its own index and advances
  // by 64 until all 4096 elements are covered.
  for(uint idx = local_invocation_index; (idx < 4096u); idx = (idx + 64u)) {
    uint const i = (idx / 64u);
    uint const i_1 = (idx % 64u);
    (*(tint_symbol_26)).arr[i].arr[i_1] = float();
  }
  // Barrier after the zero-fill, before main_1 uses the arrays.
  threadgroup_barrier(mem_flags::mem_threadgroup);
  *(tint_symbol_27) = gl_LocalInvocationID_param;
  *(tint_symbol_28) = gl_GlobalInvocationID_param;
  main_1(x_48, x_165, x_185, x_54, tint_symbol_29, tint_symbol_30, tint_symbol_31, tint_symbol_28, tint_symbol_32, tint_symbol_27, tint_symbol_26, tint_symbol_25);
}

// Kernel entry point.
// Bindings: buffer(0) = Uniforms, buffer(1) = output, buffer(2) = A,
// buffer(3) = B. Declares the threadgroup staging arrays and the
// thread-private variables, then delegates to tint_symbol_1_inner.
kernel void tint_symbol_1(uint3 gl_LocalInvocationID_param [[thread_position_in_threadgroup]],
                          uint3 gl_GlobalInvocationID_param [[thread_position_in_grid]],
                          uint local_invocation_index [[thread_index_in_threadgroup]],
                          constant Uniforms& x_48 [[buffer(0)]],
                          const device ssbA& x_165 [[buffer(2)]],
                          const device ssbB& x_185 [[buffer(3)]],
                          device ssbOut& x_54 [[buffer(1)]]) {
  threadgroup tint_array_wrapper_2 tint_symbol_33;
  threadgroup tint_array_wrapper tint_symbol_34;
  thread uint3 tint_symbol_35 = 0u;
  thread uint3 tint_symbol_36 = 0u;
  thread int tint_symbol_37 = 0;
  thread int tint_symbol_38 = 0;
  thread int tint_symbol_39 = 0;
  thread int tint_symbol_40 = 0;
  tint_symbol_1_inner(x_48, x_165, x_185, x_54, gl_LocalInvocationID_param, gl_GlobalInvocationID_param, local_invocation_index, &(tint_symbol_33), &(tint_symbol_34), &(tint_symbol_35), &(tint_symbol_36), &(tint_symbol_37), &(tint_symbol_38), &(tint_symbol_39), &(tint_symbol_40));
  return;
}