struct Uniforms { NAN : f32; [[size(12)]] padding : u32; aShape : vec3; [[size(4)]] padding_1 : u32; bShape : vec3; [[size(4)]] padding_2 : u32; outShape : vec3; [[size(4)]] padding_3 : u32; outShapeStrides : vec2; }; type RTArr = [[stride(4)]] array; type RTArr_1 = [[stride(4)]] array; struct ssbOut { result : RTArr_1; }; type RTArr_2 = [[stride(4)]] array; struct ssbA { A : RTArr_1; }; struct ssbB { B : RTArr_1; }; var dimAOuter_1 : i32; [[group(0), binding(3)]] var x_48 : Uniforms; var dimInner_1 : i32; var dimBOuter_1 : i32; [[group(0), binding(0)]] var x_54 : ssbOut; var gl_LocalInvocationID : vec3; var gl_GlobalInvocationID : vec3; var mm_Asub : array, 64u>; var mm_Bsub : array, 64u>; [[group(0), binding(1)]] var x_165 : ssbA; var batch : i32; [[group(0), binding(2)]] var x_185 : ssbB; fn coordsInBounds_vi2_vi2_(coord : ptr>, shape : ptr>) -> bool { var x_87 : bool; var x_88_phi : bool; let x_76 : vec2 = *(coord); let x_81 : bool = all((x_76 >= vec2(0, 0))); x_88_phi = x_81; if (x_81) { let x_84 : vec2 = *(coord); let x_85 : vec2 = *(shape); x_87 = all((x_84 < x_85)); x_88_phi = x_87; } let x_88 : bool = x_88_phi; return x_88; } fn mm_readA_i1_i1_(row : ptr, col : ptr) -> f32 { var batchASize : i32; var param_10 : vec2; var param_11 : vec2; var x_430 : f32; let x_417 : i32 = x_48.aShape.y; let x_419 : i32 = x_48.aShape.z; batchASize = (x_417 * x_419); let x_421 : i32 = *(row); let x_422 : i32 = *(col); let x_424 : i32 = dimAOuter_1; let x_425 : i32 = dimInner_1; param_10 = vec2(x_421, x_422); param_11 = vec2(x_424, x_425); let x_429 : bool = coordsInBounds_vi2_vi2_(&(param_10), &(param_11)); if (x_429) { let x_438 : i32 = batch; let x_439 : i32 = batchASize; let x_441 : i32 = *(row); let x_442 : i32 = dimInner_1; let x_445 : i32 = *(col); let x_448 : f32 = x_165.A[(((x_438 * x_439) + (x_441 * x_442)) + x_445)]; x_430 = x_448; } else { x_430 = 0.0; } let x_450 : f32 = x_430; return x_450; } fn mm_readB_i1_i1_(row_1 : ptr, col_1 : ptr) -> f32 { var batchBSize : i32; var param_12 : vec2; var param_13 : vec2; var x_468 : f32; let x_455 : i32 = x_48.bShape.y; let x_457 : i32 = x_48.bShape.z; batchBSize = (x_455 * x_457); let x_459 : i32 = *(row_1); let x_460 : i32 = *(col_1); let x_462 : i32 = dimInner_1; let x_463 : i32 = dimBOuter_1; param_12 = vec2(x_459, x_460); param_13 = vec2(x_462, x_463); let x_467 : bool = coordsInBounds_vi2_vi2_(&(param_12), &(param_13)); if (x_467) { let x_475 : i32 = batch; let x_476 : i32 = batchBSize; let x_478 : i32 = *(row_1); let x_479 : i32 = dimBOuter_1; let x_482 : i32 = *(col_1); let x_485 : f32 = x_185.B[(((x_475 * x_476) + (x_478 * x_479)) + x_482)]; x_468 = x_485; } else { x_468 = 0.0; } let x_487 : f32 = x_468; return x_487; } fn getOutputFlatIndex_vi3_(coords : ptr>) -> i32 { let x_99 : vec3 = *(coords); let x_105 : i32 = x_48.outShapeStrides.x; let x_107 : i32 = x_48.outShapeStrides.y; return i32(dot(vec3(x_99), vec3(vec3(x_105, x_107, 1)))); } fn setOutput_i1_f1_(flatIndex : ptr, value : ptr) { let x_95 : i32 = *(flatIndex); let x_96 : f32 = *(value); x_54.result[x_95] = x_96; return; } fn setOutput_i1_i1_i1_f1_(d0 : ptr, d1 : ptr, d2 : ptr, value_1 : ptr) { var flatIndex_1 : i32; var param : vec3; var param_1 : i32; var param_2 : f32; let x_115 : i32 = *(d0); let x_116 : i32 = *(d1); let x_117 : i32 = *(d2); param = vec3(x_115, x_116, x_117); let x_120 : i32 = getOutputFlatIndex_vi3_(&(param)); flatIndex_1 = x_120; let x_122 : i32 = flatIndex_1; param_1 = x_122; let x_124 : f32 = *(value_1); param_2 = x_124; setOutput_i1_f1_(&(param_1), &(param_2)); return; } fn mm_write_i1_i1_f1_(row_2 : ptr, col_2 : ptr, value_2 : ptr) { var outCoord : vec3; var param_14 : i32; var param_15 : i32; var param_16 : i32; var param_17 : f32; let x_491 : i32 = batch; let x_492 : i32 = *(row_2); let x_493 : i32 = *(col_2); outCoord = vec3(x_491, x_492, x_493); let x_496 : i32 = batch; param_14 = x_496; let x_498 : i32 = *(row_2); param_15 = x_498; let x_500 : i32 = *(col_2); param_16 = x_500; let x_502 : f32 = *(value_2); param_17 = x_502; setOutput_i1_i1_i1_f1_(&(param_14), &(param_15), &(param_16), &(param_17)); return; } fn mm_matMul_i1_i1_i1_(dimAOuter : ptr, dimInner : ptr, dimBOuter : ptr) { var tileRow : i32; var tileCol : i32; var globalRow : i32; var globalCol : i32; var numTiles : i32; var innerRow : i32; var innerCol : i32; var acc : array, 1u>; var tileColA : i32; var tileRowB : i32; var t : i32; var innerRow_1 : i32; var innerCol_1 : i32; var inputRow : i32; var inputCol : i32; var param_3 : i32; var param_4 : i32; var innerRow_2 : i32; var innerCol_2 : i32; var inputRow_1 : i32; var inputCol_1 : i32; var param_5 : i32; var param_6 : i32; var k : i32; var inner : i32; var BCached : array; var innerRow_3 : i32; var ACached : f32; var innerCol_3 : i32; var innerRow_4 : i32; var innerCol_4 : i32; var param_7 : i32; var param_8 : i32; var param_9 : f32; let x_132 : u32 = gl_LocalInvocationID.y; tileRow = (bitcast(x_132) * 1); let x_137 : u32 = gl_LocalInvocationID.x; tileCol = (bitcast(x_137) * 1); let x_143 : u32 = gl_GlobalInvocationID.y; globalRow = (bitcast(x_143) * 1); let x_148 : u32 = gl_GlobalInvocationID.x; globalCol = (bitcast(x_148) * 1); let x_152 : i32 = *(dimInner); numTiles = (((x_152 - 1) / 64) + 1); innerRow = 0; loop { let x_163 : i32 = innerRow; if ((x_163 < 1)) { } else { break; } innerCol = 0; loop { let x_171 : i32 = innerCol; if ((x_171 < 1)) { } else { break; } let x_177 : i32 = innerRow; let x_178 : i32 = innerCol; acc[x_177][x_178] = 0.0; continuing { let x_181 : i32 = innerCol; innerCol = (x_181 + 1); } } continuing { let x_183 : i32 = innerRow; innerRow = (x_183 + 1); } } let x_187 : u32 = gl_LocalInvocationID.x; tileColA = (bitcast(x_187) * 64); let x_192 : u32 = gl_LocalInvocationID.y; tileRowB = (bitcast(x_192) * 1); t = 0; loop { let x_201 : i32 = t; let x_202 : i32 = numTiles; if ((x_201 < x_202)) { } else { break; } innerRow_1 = 0; loop { let x_210 : i32 = innerRow_1; if ((x_210 < 1)) { } else { break; } innerCol_1 = 0; loop { let x_218 : i32 = innerCol_1; if ((x_218 < 64)) { } else { break; } let x_221 : i32 = tileRow; let x_222 : i32 = innerRow_1; inputRow = (x_221 + x_222); let x_225 : i32 = tileColA; let x_226 : i32 = innerCol_1; inputCol = (x_225 + x_226); let x_233 : i32 = inputRow; let x_234 : i32 = inputCol; let x_235 : i32 = globalRow; let x_236 : i32 = innerRow_1; let x_238 : i32 = t; let x_240 : i32 = inputCol; param_3 = (x_235 + x_236); param_4 = ((x_238 * 64) + x_240); let x_244 : f32 = mm_readA_i1_i1_(&(param_3), &(param_4)); mm_Asub[x_233][x_234] = x_244; continuing { let x_247 : i32 = innerCol_1; innerCol_1 = (x_247 + 1); } } continuing { let x_249 : i32 = innerRow_1; innerRow_1 = (x_249 + 1); } } innerRow_2 = 0; loop { let x_257 : i32 = innerRow_2; if ((x_257 < 1)) { } else { break; } innerCol_2 = 0; loop { let x_265 : i32 = innerCol_2; if ((x_265 < 1)) { } else { break; } let x_268 : i32 = tileRowB; let x_269 : i32 = innerRow_2; inputRow_1 = (x_268 + x_269); let x_272 : i32 = tileCol; let x_273 : i32 = innerCol_2; inputCol_1 = (x_272 + x_273); let x_278 : i32 = inputRow_1; let x_279 : i32 = inputCol_1; let x_280 : i32 = t; let x_282 : i32 = inputRow_1; let x_284 : i32 = globalCol; let x_285 : i32 = innerCol_2; param_5 = ((x_280 * 64) + x_282); param_6 = (x_284 + x_285); let x_289 : f32 = mm_readB_i1_i1_(&(param_5), &(param_6)); mm_Bsub[x_278][x_279] = x_289; continuing { let x_291 : i32 = innerCol_2; innerCol_2 = (x_291 + 1); } } continuing { let x_293 : i32 = innerRow_2; innerRow_2 = (x_293 + 1); } } workgroupBarrier(); k = 0; loop { let x_302 : i32 = k; if ((x_302 < 64)) { } else { break; } inner = 0; loop { let x_310 : i32 = inner; if ((x_310 < 1)) { } else { break; } let x_314 : i32 = inner; let x_315 : i32 = k; let x_316 : i32 = tileCol; let x_317 : i32 = inner; let x_320 : f32 = mm_Bsub[x_315][(x_316 + x_317)]; BCached[x_314] = x_320; continuing { let x_322 : i32 = inner; inner = (x_322 + 1); } } innerRow_3 = 0; loop { let x_330 : i32 = innerRow_3; if ((x_330 < 1)) { } else { break; } let x_333 : i32 = tileRow; let x_334 : i32 = innerRow_3; let x_336 : i32 = k; let x_338 : f32 = mm_Asub[(x_333 + x_334)][x_336]; ACached = x_338; innerCol_3 = 0; loop { let x_345 : i32 = innerCol_3; if ((x_345 < 1)) { } else { break; } let x_347 : i32 = innerRow_3; let x_348 : i32 = innerCol_3; let x_349 : f32 = ACached; let x_350 : i32 = innerCol_3; let x_352 : f32 = BCached[x_350]; let x_355 : f32 = acc[x_347][x_348]; acc[x_347][x_348] = (x_355 + (x_349 * x_352)); continuing { let x_358 : i32 = innerCol_3; innerCol_3 = (x_358 + 1); } } continuing { let x_360 : i32 = innerRow_3; innerRow_3 = (x_360 + 1); } } continuing { let x_362 : i32 = k; k = (x_362 + 1); } } workgroupBarrier(); continuing { let x_364 : i32 = t; t = (x_364 + 1); } } innerRow_4 = 0; loop { let x_372 : i32 = innerRow_4; if ((x_372 < 1)) { } else { break; } innerCol_4 = 0; loop { var x_393 : bool; var x_394_phi : bool; let x_380 : i32 = innerCol_4; if ((x_380 < 1)) { } else { break; } let x_382 : i32 = globalCol; let x_383 : i32 = innerCol_4; let x_385 : i32 = *(dimBOuter); let x_386 : bool = ((x_382 + x_383) < x_385); x_394_phi = x_386; if (x_386) { let x_389 : i32 = globalRow; let x_390 : i32 = innerRow_4; let x_392 : i32 = *(dimAOuter); x_393 = ((x_389 + x_390) < x_392); x_394_phi = x_393; } let x_394 : bool = x_394_phi; if (x_394) { let x_397 : i32 = globalRow; let x_398 : i32 = innerRow_4; let x_400 : i32 = globalCol; let x_401 : i32 = innerCol_4; let x_403 : i32 = innerRow_4; let x_404 : i32 = innerCol_4; param_7 = (x_397 + x_398); param_8 = (x_400 + x_401); let x_409 : f32 = acc[x_403][x_404]; param_9 = x_409; mm_write_i1_i1_f1_(&(param_7), &(param_8), &(param_9)); } continuing { let x_411 : i32 = innerCol_4; innerCol_4 = (x_411 + 1); } } continuing { let x_413 : i32 = innerRow_4; innerRow_4 = (x_413 + 1); } } return; } fn main_1() { var param_18 : i32; var param_19 : i32; var param_20 : i32; let x_67 : i32 = x_48.aShape.y; dimAOuter_1 = x_67; let x_71 : i32 = x_48.aShape.z; dimInner_1 = x_71; let x_75 : i32 = x_48.bShape.z; dimBOuter_1 = x_75; let x_505 : u32 = gl_GlobalInvocationID.z; batch = bitcast(x_505); let x_508 : i32 = dimAOuter_1; param_18 = x_508; let x_510 : i32 = dimInner_1; param_19 = x_510; let x_512 : i32 = dimBOuter_1; param_20 = x_512; mm_matMul_i1_i1_i1_(&(param_18), &(param_19), &(param_20)); return; } [[stage(compute), workgroup_size(1, 64, 1)]] fn main([[builtin(local_invocation_id)]] gl_LocalInvocationID_param : vec3, [[builtin(global_invocation_id)]] gl_GlobalInvocationID_param : vec3) { gl_LocalInvocationID = gl_LocalInvocationID_param; gl_GlobalInvocationID = gl_GlobalInvocationID_param; main_1(); }