124 lines
4.4 KiB
WebGPU Shading Language
124 lines
4.4 KiB
WebGPU Shading Language
// Matrix dimensions, bound as the uniform buffer `uniforms` (group 0, binding 3).
struct Uniforms {
    // Row bound used for the first matrix (mm_readA) and the result (mm_write).
    dimAOuter : u32,
    // Shared dimension: column bound of the first matrix and row bound of the
    // second matrix (mm_readA / mm_readB).
    dimInner : u32,
    // Column bound used for the second matrix (mm_readB) and the result (mm_write).
    dimBOuter : u32,
};
|
|
// Storage-buffer layout shared by both inputs and the output: a single
// runtime-sized array of f32 that the mm_* helpers index as
// `row * <column count> + col`.
struct Matrix {
    numbers : array<f32>,
};
|
|
|
|
// Left-hand operand; only accessed through mm_readA.
@group(0) @binding(0) var<storage, read> firstMatrix : Matrix;

// Right-hand operand; only accessed through mm_readB.
@group(0) @binding(1) var<storage, read> secondMatrix : Matrix;

// Product; only accessed through mm_write.
// NOTE(review): the `write` access mode was dropped from later WGSL revisions
// (storage buffers are `read` or `read_write`) - confirm which toolchain
// version this shader targets before changing it.
@group(0) @binding(2) var<storage, write> resultMatrix : Matrix;

// Dimensions of the three matrices.
@group(0) @binding(3) var<uniform> uniforms : Uniforms;
|
|
|
|
// Bounds-checked read of element (row, col) of the first matrix.
//
// The matrix is indexed with `dimInner` elements per row. Coordinates outside
// dimAOuter x dimInner yield 0.0, so partially filled edge tiles contribute
// nothing to the accumulated product.
fn mm_readA(row : u32, col : u32) -> f32 {
    if (row >= uniforms.dimAOuter || col >= uniforms.dimInner) {
        return 0.0;
    }
    return firstMatrix.numbers[row * uniforms.dimInner + col];
}
|
|
|
|
// Bounds-checked read of element (row, col) of the second matrix.
//
// The matrix is indexed with `dimBOuter` elements per row. Coordinates outside
// dimInner x dimBOuter yield 0.0, so partially filled edge tiles contribute
// nothing to the accumulated product.
fn mm_readB(row : u32, col : u32) -> f32 {
    if (row >= uniforms.dimInner || col >= uniforms.dimBOuter) {
        return 0.0;
    }
    return secondMatrix.numbers[row * uniforms.dimBOuter + col];
}
|
|
|
|
// Bounds-checked store of `value` into element (row, col) of the result.
//
// The result is indexed with `dimBOuter` elements per row. Stores whose
// coordinates fall outside dimAOuter x dimBOuter are silently dropped.
fn mm_write(row : u32, col : u32, value : f32) {
    if (row >= uniforms.dimAOuter || col >= uniforms.dimBOuter) {
        return;
    }
    resultMatrix.numbers[row * uniforms.dimBOuter + col] = value;
}
|
|
|
|
// Size of the result patch accumulated by a single invocation.
// NOTE(review): module-scope `let` matches the WGSL revision this file is
// written against; later revisions spell this `const` - confirm the target
// toolchain before changing it.
let RowPerThread : u32 = 4u;
let ColPerThread : u32 = 4u;

// Tile extents. With the 16 x 16 workgroup declared on `main`, 16 * 4 = 64
// rows/columns of the result are covered per workgroup.
let TileAOuter : u32 = 64u;
let TileBOuter : u32 = 64u;
let TileInner : u32 = 64u;

// Workgroup-shared staging tiles for the current slice of each operand.
// mm_Asub is indexed [row in tile][position along the shared dimension];
// mm_Bsub is indexed [position along the shared dimension][column in tile].
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;
|
|
// Tiled matrix multiplication: resultMatrix = firstMatrix * secondMatrix.
//
// A workgroup is 16 x 16 invocations and every invocation accumulates a
// RowPerThread x ColPerThread (4 x 4) patch of the result, so one workgroup
// produces a 64 x 64 output tile. The shared dimension is processed in slices
// of TileInner: each slice of A and B is staged in workgroup memory
// (mm_Asub / mm_Bsub), then every invocation multiplies its rows of the A tile
// by its columns of the B tile.
//
// Out-of-range coordinates are handled by mm_readA / mm_readB (which return 0)
// and mm_write (which drops the store), so the dimensions need not be
// multiples of the tile size.
//
// NOTE(review): assumes the host dispatches enough workgroups to cover the
// result (ceil(dimBOuter / 64) in x, ceil(dimAOuter / 64) in y) - confirm
// against the caller.
@stage(compute) @workgroup_size(16, 16, 1)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(global_invocation_id) global_id : vec3<u32>) {
    // Top-left corner of this invocation's patch inside the workgroup tile.
    let tileRow : u32 = local_id.y * RowPerThread;
    let tileCol : u32 = local_id.x * ColPerThread;

    // Top-left corner of the same patch in the full result matrix.
    let globalRow : u32 = global_id.y * RowPerThread;
    let globalCol : u32 = global_id.x * ColPerThread;

    // ceil(dimInner / TileInner). The add-then-divide form yields zero tiles
    // when dimInner == 0; `(dimInner - 1u) / TileInner + 1u` would wrap around
    // in u32 and spin for roughly 2^26 iterations.
    let numTiles : u32 = (uniforms.dimInner + TileInner - 1u) / TileInner;

    var acc : array<f32, 16>;
    var ACached : f32;
    var BCached : array<f32, 4>;

    // Without this initialization strange values show up in acc.
    // TODO: Remove it once the following bug is fixed.
    // https://bugs.chromium.org/p/tint/issues/detail?id=759
    for (var index : u32 = 0u; index < RowPerThread * ColPerThread; index = index + 1u) {
        acc[index] = 0.;
    }

    // Each invocation stages a 4 x 4 piece of both tiles: the 16 invocations
    // along x split the TileInner columns of the A tile, and the 16
    // invocations along y split the TileInner rows of the B tile.
    let ColPerThreadA : u32 = TileInner / 16u;
    let tileColA : u32 = local_id.x * ColPerThreadA;
    let RowPerThreadB : u32 = TileInner / 16u;
    let tileRowB : u32 = local_id.y * RowPerThreadB;

    // Loop over shared dimension. The bound depends only on uniform data, so
    // every invocation executes the barriers below the same number of times.
    for (var t : u32 = 0u; t < numTiles; t = t + 1u) {
        // Load one tile of A into local memory.
        for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
            for (var innerCol : u32 = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
                let inputRow : u32 = tileRow + innerRow;
                let inputCol : u32 = tileColA + innerCol;
                mm_Asub[inputRow][inputCol] = mm_readA(globalRow + innerRow, t * TileInner + inputCol);
            }
        }

        // Load one tile of B into local memory.
        for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
            for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
                let inputRow : u32 = tileRowB + innerRow;
                let inputCol : u32 = tileCol + innerCol;

                // The element read is B's (inputRow, inputCol) within the
                // current tile, so it must be stored at that same position.
                // Indexing the row by innerCol would fill only rows 0..3 and
                // make invocations with different local_id.y overwrite each
                // other, while the compute loop below reads rows 0..TileInner-1.
                mm_Bsub[inputRow][inputCol] = mm_readB(t * TileInner + inputRow, globalCol + innerCol);
            }
        }

        // Both tiles must be completely written before any invocation reads them.
        workgroupBarrier();

        // Compute acc values for a single thread.
        for (var k : u32 = 0u; k < TileInner; k = k + 1u) {
            // Cache this invocation's four B values for step k.
            for (var inner : u32 = 0u; inner < ColPerThread; inner = inner + 1u) {
                BCached[inner] = mm_Bsub[k][tileCol + inner];
            }

            for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
                ACached = mm_Asub[tileRow + innerRow][k];
                for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
                    let index : u32 = innerRow * ColPerThread + innerCol;
                    acc[index] = acc[index] + ACached * BCached[innerCol];
                }
            }
        }

        // All reads of the current tiles must finish before the next
        // iteration overwrites them.
        workgroupBarrier();
    }

    // Write the accumulated patch; mm_write discards out-of-range elements.
    for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
            let index : u32 = innerRow * ColPerThread + innerCol;
            mm_write(globalRow + innerRow,
                     globalCol + innerCol,
                     acc[index]);
        }
    }
}
|