test: Add bug case for tint:914
Bug: tint:914 Change-Id: Id17b675e947b170e460c415c15d5d75f311e65b0 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55247 Kokoro: Kokoro <noreply+kokoro@google.com> Auto-Submit: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com> Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
parent
0f916164ae
commit
c15baf695d
|
@ -0,0 +1,70 @@
|
|||
[[block]] struct Uniforms {
|
||||
dstTextureFlipY : u32;
|
||||
channelCount : u32;
|
||||
srcCopyOrigin : vec2<u32>;
|
||||
dstCopyOrigin : vec2<u32>;
|
||||
copySize : vec2<u32>;
|
||||
};
|
||||
[[block]] struct OutputBuf {
|
||||
result : array<u32>;
|
||||
};
|
||||
[[group(0), binding(0)]] var src : texture_2d<f32>;
|
||||
[[group(0), binding(1)]] var dst : texture_2d<f32>;
|
||||
[[group(0), binding(2)]] var<storage, read_write> output : OutputBuf;
|
||||
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;
|
||||
fn aboutEqual(value : f32, expect : f32) -> bool {
|
||||
// The value diff should be smaller than the hard coded tolerance.
|
||||
return abs(value - expect) < 0.001;
|
||||
}
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
|
||||
let srcSize : vec2<i32> = textureDimensions(src);
|
||||
let dstSize : vec2<i32> = textureDimensions(dst);
|
||||
let dstTexCoord : vec2<u32> = vec2<u32>(GlobalInvocationID.xy);
|
||||
let nonCoveredColor : vec4<f32> =
|
||||
vec4<f32>(0.0, 1.0, 0.0, 1.0); // should be green
|
||||
|
||||
var success : bool = true;
|
||||
if (dstTexCoord.x < uniforms.dstCopyOrigin.x ||
|
||||
dstTexCoord.y < uniforms.dstCopyOrigin.y ||
|
||||
dstTexCoord.x >= uniforms.dstCopyOrigin.x + uniforms.copySize.x ||
|
||||
dstTexCoord.y >= uniforms.dstCopyOrigin.y + uniforms.copySize.y) {
|
||||
success = success &&
|
||||
all(textureLoad(dst, vec2<i32>(dstTexCoord), 0) == nonCoveredColor);
|
||||
} else {
|
||||
// Calculate source texture coord.
|
||||
var srcTexCoord : vec2<u32> = dstTexCoord - uniforms.dstCopyOrigin +
|
||||
uniforms.srcCopyOrigin;
|
||||
// Note that |flipY| equals flip src texture firstly and then do copy from src
|
||||
// subrect to dst subrect. This helps on blink part to handle some input texture
|
||||
// which is flipped and need to unpack flip during the copy.
|
||||
// We need to calculate the expect y coord based on this rule.
|
||||
if (uniforms.dstTextureFlipY == 1u) {
|
||||
srcTexCoord.y = u32(srcSize.y) - srcTexCoord.y - 1u;
|
||||
}
|
||||
|
||||
let srcColor : vec4<f32> = textureLoad(src, vec2<i32>(srcTexCoord), 0);
|
||||
let dstColor : vec4<f32> = textureLoad(dst, vec2<i32>(dstTexCoord), 0);
|
||||
|
||||
// Not use loop and variable index format to workaround
|
||||
// crbug.com/tint/638.
|
||||
if (uniforms.channelCount == 2u) { // All have rg components.
|
||||
success = success &&
|
||||
aboutEqual(dstColor.r, srcColor.r) &&
|
||||
aboutEqual(dstColor.g, srcColor.g);
|
||||
} else {
|
||||
success = success &&
|
||||
aboutEqual(dstColor.r, srcColor.r) &&
|
||||
aboutEqual(dstColor.g, srcColor.g) &&
|
||||
aboutEqual(dstColor.b, srcColor.b) &&
|
||||
aboutEqual(dstColor.a, srcColor.a);
|
||||
}
|
||||
}
|
||||
let outputIndex : u32 = GlobalInvocationID.y * u32(dstSize.x) +
|
||||
GlobalInvocationID.x;
|
||||
if (success) {
|
||||
output.result[outputIndex] = 1u;
|
||||
} else {
|
||||
output.result[outputIndex] = 0u;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
[[block]] struct Uniforms {
|
||||
dimAOuter : u32;
|
||||
dimInner : u32;
|
||||
dimBOuter : u32;
|
||||
};
|
||||
[[block]] struct Matrix {
|
||||
numbers: array<f32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix;
|
||||
[[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix;
|
||||
[[group(0), binding(2)]] var<storage, write> resultMatrix : Matrix;
|
||||
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn mm_readA(row : u32, col : u32) -> f32 {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimInner)
|
||||
{
|
||||
let result : f32 = firstMatrix.numbers[row * uniforms.dimInner + col];
|
||||
return result;
|
||||
}
|
||||
return 0.;
|
||||
}
|
||||
|
||||
fn mm_readB(row : u32, col : u32) -> f32 {
|
||||
if (row < uniforms.dimInner && col < uniforms.dimBOuter)
|
||||
{
|
||||
let result : f32 = secondMatrix.numbers[row * uniforms.dimBOuter + col];
|
||||
return result;
|
||||
}
|
||||
return 0.;
|
||||
}
|
||||
|
||||
fn mm_write(row : u32, col : u32, value : f32) {
|
||||
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
|
||||
{
|
||||
let index : u32 = col + row * uniforms.dimBOuter;
|
||||
resultMatrix.numbers[index] = value;
|
||||
}
|
||||
}
|
||||
|
||||
let RowPerThread : u32 = 4u;
|
||||
let ColPerThread : u32 = 4u;
|
||||
let TileAOuter : u32 = 64u;
|
||||
let TileBOuter : u32 = 64u;
|
||||
let TileInner : u32 = 64u;
|
||||
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
|
||||
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;
|
||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
|
||||
[[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
||||
let tileRow : u32 = local_id.y * RowPerThread;
|
||||
let tileCol : u32 = local_id.x * ColPerThread;
|
||||
|
||||
let globalRow : u32 = global_id.y * RowPerThread;
|
||||
let globalCol : u32 = global_id.x * ColPerThread;
|
||||
|
||||
let numTiles : u32 = (uniforms.dimInner - 1u) / TileInner + 1u;
|
||||
|
||||
var acc: array<f32, 16>;
|
||||
var ACached : f32;
|
||||
var BCached : array<f32, 4>;
|
||||
|
||||
// Without this initialization strange values show up in acc.
|
||||
// TODO: Remove it once the following bug is fixed.
|
||||
// https://bugs.chromium.org/p/tint/issues/detail?id=759
|
||||
for (var index : u32 = 0u; index < RowPerThread * ColPerThread; index = index + 1u) {
|
||||
acc[index] = 0.;
|
||||
}
|
||||
|
||||
let ColPerThreadA : u32 = TileInner / 16u;
|
||||
let tileColA : u32 = local_id.x * ColPerThreadA;
|
||||
let RowPerThreadB : u32 = TileInner / 16u;
|
||||
let tileRowB : u32 = local_id.y * RowPerThreadB;
|
||||
|
||||
// Loop over shared dimension.
|
||||
for (var t : u32 = 0u; t < numTiles; t = t + 1u) {
|
||||
// Load one tile of A into local memory.
|
||||
for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
|
||||
for (var innerCol : u32 = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
|
||||
let inputRow : u32 = tileRow + innerRow;
|
||||
let inputCol : u32 = tileColA + innerCol;
|
||||
mm_Asub[inputRow][inputCol] = mm_readA(globalRow + innerRow, t * TileInner + inputCol);
|
||||
}
|
||||
}
|
||||
// Load one tile of B into local memory.
|
||||
for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
|
||||
for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
|
||||
let inputRow : u32 = tileRowB + innerRow;
|
||||
let inputCol : u32 = tileCol + innerCol;
|
||||
|
||||
mm_Bsub[innerCol][inputCol] = mm_readB(t * TileInner + inputRow, globalCol + innerCol);;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute acc values for a single thread.
|
||||
for (var k : u32 = 0u; k < TileInner; k = k + 1u) {
|
||||
for (var inner : u32 = 0u; inner < ColPerThread; inner = inner + 1u) {
|
||||
BCached[inner] = mm_Bsub[k][tileCol + inner];
|
||||
}
|
||||
|
||||
for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
|
||||
ACached = mm_Asub[tileRow + innerRow][k];
|
||||
for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
|
||||
let index : u32 = innerRow * ColPerThread + innerCol;
|
||||
acc[index] = acc[index] + ACached * BCached[innerCol];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
for (var innerRow : u32 = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
|
||||
for (var innerCol : u32 = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
|
||||
let index : u32 = innerRow * ColPerThread + innerCol;
|
||||
mm_write(globalRow + innerRow,
|
||||
globalCol + innerCol,
|
||||
acc[index]);
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,78 @@
|
|||
SKIP: FAILED
|
||||
|
||||
|
||||
|
||||
Validation Failure:
|
||||
|
||||
Compilation failed:
|
||||
|
||||
program_source:56:31: warning: equality comparison with extraneous parentheses
|
||||
if ((local_invocation_index == 0u)) {
|
||||
~~~~~~~~~~~~~~~~~~~~~~~^~~~~
|
||||
program_source:56:31: note: remove extraneous parentheses around the comparison to silence this warning
|
||||
if ((local_invocation_index == 0u)) {
|
||||
~ ^ ~
|
||||
program_source:56:31: note: use '=' to turn this equality comparison into an assignment
|
||||
if ((local_invocation_index == 0u)) {
|
||||
^~
|
||||
=
|
||||
program_source:122:30: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const inputRow;
|
||||
^
|
||||
= 0
|
||||
program_source:123:30: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const inputCol;
|
||||
^
|
||||
= 0
|
||||
program_source:133:30: error: cannot assign to variable 'inputRow' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
inputRow = (tileRow + innerRow);
|
||||
~~~~~~~~ ^
|
||||
program_source:122:30: note: variable 'inputRow' declared const here
|
||||
uint const inputRow;
|
||||
~~~~~~~~~~~^~~~~~~~
|
||||
program_source:134:30: error: cannot assign to variable 'inputCol' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
inputCol = (tileColA + innerCol);
|
||||
~~~~~~~~ ^
|
||||
program_source:123:30: note: variable 'inputCol' declared const here
|
||||
uint const inputCol;
|
||||
~~~~~~~~~~~^~~~~~~~
|
||||
program_source:159:30: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const inputRow;
|
||||
^
|
||||
= 0
|
||||
program_source:160:30: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const inputCol;
|
||||
^
|
||||
= 0
|
||||
program_source:170:30: error: cannot assign to variable 'inputRow' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
inputRow = (tileRowB + innerRow);
|
||||
~~~~~~~~ ^
|
||||
program_source:159:30: note: variable 'inputRow' declared const here
|
||||
uint const inputRow;
|
||||
~~~~~~~~~~~^~~~~~~~
|
||||
program_source:171:30: error: cannot assign to variable 'inputCol' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
inputCol = (tileCol + innerCol);
|
||||
~~~~~~~~ ^
|
||||
program_source:160:30: note: variable 'inputCol' declared const here
|
||||
uint const inputCol;
|
||||
~~~~~~~~~~~^~~~~~~~
|
||||
program_source:228:36: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const index;
|
||||
^
|
||||
= 0
|
||||
program_source:238:33: error: cannot assign to variable 'index' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
index = ((innerRow * ColPerThread) + innerCol);
|
||||
~~~~~ ^
|
||||
program_source:228:36: note: variable 'index' declared const here
|
||||
uint const index;
|
||||
~~~~~~~~~~~^~~~~
|
||||
program_source:270:24: error: default initialization of an object of const type 'const uint' (aka 'const unsigned int')
|
||||
uint const index;
|
||||
^
|
||||
= 0
|
||||
program_source:280:21: error: cannot assign to variable 'index' with const-qualified type 'const uint' (aka 'const unsigned int')
|
||||
index = ((innerRow * ColPerThread) + innerCol);
|
||||
~~~~~ ^
|
||||
program_source:270:24: note: variable 'index' declared const here
|
||||
uint const index;
|
||||
~~~~~~~~~~~^~~~~
|
|
@ -0,0 +1,594 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 356
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main" %tint_symbol_2 %tint_symbol %tint_symbol_1
|
||||
OpExecutionMode %main LocalSize 16 16 1
|
||||
OpName %Matrix "Matrix"
|
||||
OpMemberName %Matrix 0 "numbers"
|
||||
OpName %firstMatrix "firstMatrix"
|
||||
OpName %secondMatrix "secondMatrix"
|
||||
OpName %resultMatrix "resultMatrix"
|
||||
OpName %Uniforms "Uniforms"
|
||||
OpMemberName %Uniforms 0 "dimAOuter"
|
||||
OpMemberName %Uniforms 1 "dimInner"
|
||||
OpMemberName %Uniforms 2 "dimBOuter"
|
||||
OpName %uniforms "uniforms"
|
||||
OpName %RowPerThread "RowPerThread"
|
||||
OpName %RowPerThread "ColPerThread"
|
||||
OpName %TileAOuter "TileAOuter"
|
||||
OpName %TileAOuter "TileBOuter"
|
||||
OpName %TileAOuter "TileInner"
|
||||
OpName %mm_Asub "mm_Asub"
|
||||
OpName %mm_Bsub "mm_Bsub"
|
||||
OpName %tint_symbol "tint_symbol"
|
||||
OpName %tint_symbol_1 "tint_symbol_1"
|
||||
OpName %tint_symbol_2 "tint_symbol_2"
|
||||
OpName %mm_readA "mm_readA"
|
||||
OpName %row "row"
|
||||
OpName %col "col"
|
||||
OpName %mm_readB "mm_readB"
|
||||
OpName %row_0 "row"
|
||||
OpName %col_0 "col"
|
||||
OpName %mm_write "mm_write"
|
||||
OpName %row_1 "row"
|
||||
OpName %col_1 "col"
|
||||
OpName %value "value"
|
||||
OpName %main "main"
|
||||
OpName %acc "acc"
|
||||
OpName %ACached "ACached"
|
||||
OpName %BCached "BCached"
|
||||
OpName %index "index"
|
||||
OpName %t "t"
|
||||
OpName %innerRow "innerRow"
|
||||
OpName %innerCol "innerCol"
|
||||
OpName %innerRow_0 "innerRow"
|
||||
OpName %innerCol_0 "innerCol"
|
||||
OpName %k "k"
|
||||
OpName %inner "inner"
|
||||
OpName %innerRow_1 "innerRow"
|
||||
OpName %innerCol_1 "innerCol"
|
||||
OpName %innerRow_2 "innerRow"
|
||||
OpName %innerCol_2 "innerCol"
|
||||
OpDecorate %Matrix Block
|
||||
OpMemberDecorate %Matrix 0 Offset 0
|
||||
OpDecorate %_runtimearr_float ArrayStride 4
|
||||
OpDecorate %firstMatrix NonWritable
|
||||
OpDecorate %firstMatrix DescriptorSet 0
|
||||
OpDecorate %firstMatrix Binding 0
|
||||
OpDecorate %secondMatrix NonWritable
|
||||
OpDecorate %secondMatrix DescriptorSet 0
|
||||
OpDecorate %secondMatrix Binding 1
|
||||
OpDecorate %resultMatrix NonReadable
|
||||
OpDecorate %resultMatrix DescriptorSet 0
|
||||
OpDecorate %resultMatrix Binding 2
|
||||
OpDecorate %Uniforms Block
|
||||
OpMemberDecorate %Uniforms 0 Offset 0
|
||||
OpMemberDecorate %Uniforms 1 Offset 4
|
||||
OpMemberDecorate %Uniforms 2 Offset 8
|
||||
OpDecorate %uniforms NonWritable
|
||||
OpDecorate %uniforms DescriptorSet 0
|
||||
OpDecorate %uniforms Binding 3
|
||||
OpDecorate %_arr_float_TileAOuter ArrayStride 4
|
||||
OpDecorate %_arr__arr_float_TileAOuter_TileAOuter ArrayStride 256
|
||||
OpDecorate %tint_symbol BuiltIn LocalInvocationId
|
||||
OpDecorate %tint_symbol_1 BuiltIn GlobalInvocationId
|
||||
OpDecorate %tint_symbol_2 BuiltIn LocalInvocationIndex
|
||||
OpDecorate %_arr_float_uint_16 ArrayStride 4
|
||||
OpDecorate %_arr_float_RowPerThread ArrayStride 4
|
||||
%float = OpTypeFloat 32
|
||||
%_runtimearr_float = OpTypeRuntimeArray %float
|
||||
%Matrix = OpTypeStruct %_runtimearr_float
|
||||
%_ptr_StorageBuffer_Matrix = OpTypePointer StorageBuffer %Matrix
|
||||
%firstMatrix = OpVariable %_ptr_StorageBuffer_Matrix StorageBuffer
|
||||
%secondMatrix = OpVariable %_ptr_StorageBuffer_Matrix StorageBuffer
|
||||
%resultMatrix = OpVariable %_ptr_StorageBuffer_Matrix StorageBuffer
|
||||
%uint = OpTypeInt 32 0
|
||||
%Uniforms = OpTypeStruct %uint %uint %uint
|
||||
%_ptr_Uniform_Uniforms = OpTypePointer Uniform %Uniforms
|
||||
%uniforms = OpVariable %_ptr_Uniform_Uniforms Uniform
|
||||
%RowPerThread = OpConstant %uint 4
|
||||
%TileAOuter = OpConstant %uint 64
|
||||
%_arr_float_TileAOuter = OpTypeArray %float %TileAOuter
|
||||
%_arr__arr_float_TileAOuter_TileAOuter = OpTypeArray %_arr_float_TileAOuter %TileAOuter
|
||||
%_ptr_Workgroup__arr__arr_float_TileAOuter_TileAOuter = OpTypePointer Workgroup %_arr__arr_float_TileAOuter_TileAOuter
|
||||
%mm_Asub = OpVariable %_ptr_Workgroup__arr__arr_float_TileAOuter_TileAOuter Workgroup
|
||||
%mm_Bsub = OpVariable %_ptr_Workgroup__arr__arr_float_TileAOuter_TileAOuter Workgroup
|
||||
%v3uint = OpTypeVector %uint 3
|
||||
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
|
||||
%tint_symbol = OpVariable %_ptr_Input_v3uint Input
|
||||
%tint_symbol_1 = OpVariable %_ptr_Input_v3uint Input
|
||||
%_ptr_Input_uint = OpTypePointer Input %uint
|
||||
%tint_symbol_2 = OpVariable %_ptr_Input_uint Input
|
||||
%25 = OpTypeFunction %float %uint %uint
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%_ptr_Uniform_uint = OpTypePointer Uniform %uint
|
||||
%bool = OpTypeBool
|
||||
%uint_1 = OpConstant %uint 1
|
||||
%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
|
||||
%float_0 = OpConstant %float 0
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%void = OpTypeVoid
|
||||
%75 = OpTypeFunction %void %uint %uint %float
|
||||
%98 = OpTypeFunction %void
|
||||
%105 = OpConstantNull %_arr__arr_float_TileAOuter_TileAOuter
|
||||
%uint_264 = OpConstant %uint 264
|
||||
%uint_16 = OpConstant %uint 16
|
||||
%_arr_float_uint_16 = OpTypeArray %float %uint_16
|
||||
%_ptr_Function__arr_float_uint_16 = OpTypePointer Function %_arr_float_uint_16
|
||||
%129 = OpConstantNull %_arr_float_uint_16
|
||||
%_ptr_Function_float = OpTypePointer Function %float
|
||||
%132 = OpConstantNull %float
|
||||
%_arr_float_RowPerThread = OpTypeArray %float %RowPerThread
|
||||
%_ptr_Function__arr_float_RowPerThread = OpTypePointer Function %_arr_float_RowPerThread
|
||||
%136 = OpConstantNull %_arr_float_RowPerThread
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%139 = OpConstantNull %uint
|
||||
%_ptr_Workgroup_float = OpTypePointer Workgroup %float
|
||||
%mm_readA = OpFunction %float None %25
|
||||
%row = OpFunctionParameter %uint
|
||||
%col = OpFunctionParameter %uint
|
||||
%29 = OpLabel
|
||||
%32 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_0
|
||||
%33 = OpLoad %uint %32
|
||||
%34 = OpULessThan %bool %row %33
|
||||
OpSelectionMerge %36 None
|
||||
OpBranchConditional %34 %37 %36
|
||||
%37 = OpLabel
|
||||
%39 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_1
|
||||
%40 = OpLoad %uint %39
|
||||
%41 = OpULessThan %bool %col %40
|
||||
OpBranch %36
|
||||
%36 = OpLabel
|
||||
%42 = OpPhi %bool %34 %29 %41 %37
|
||||
OpSelectionMerge %43 None
|
||||
OpBranchConditional %42 %44 %43
|
||||
%44 = OpLabel
|
||||
%45 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_1
|
||||
%46 = OpLoad %uint %45
|
||||
%47 = OpIMul %uint %row %46
|
||||
%48 = OpIAdd %uint %47 %col
|
||||
%50 = OpAccessChain %_ptr_StorageBuffer_float %firstMatrix %uint_0 %48
|
||||
%51 = OpLoad %float %50
|
||||
OpReturnValue %51
|
||||
%43 = OpLabel
|
||||
OpReturnValue %float_0
|
||||
OpFunctionEnd
|
||||
%mm_readB = OpFunction %float None %25
|
||||
%row_0 = OpFunctionParameter %uint
|
||||
%col_0 = OpFunctionParameter %uint
|
||||
%56 = OpLabel
|
||||
%57 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_1
|
||||
%58 = OpLoad %uint %57
|
||||
%59 = OpULessThan %bool %row_0 %58
|
||||
OpSelectionMerge %60 None
|
||||
OpBranchConditional %59 %61 %60
|
||||
%61 = OpLabel
|
||||
%63 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_2
|
||||
%64 = OpLoad %uint %63
|
||||
%65 = OpULessThan %bool %col_0 %64
|
||||
OpBranch %60
|
||||
%60 = OpLabel
|
||||
%66 = OpPhi %bool %59 %56 %65 %61
|
||||
OpSelectionMerge %67 None
|
||||
OpBranchConditional %66 %68 %67
|
||||
%68 = OpLabel
|
||||
%69 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_2
|
||||
%70 = OpLoad %uint %69
|
||||
%71 = OpIMul %uint %row_0 %70
|
||||
%72 = OpIAdd %uint %71 %col_0
|
||||
%73 = OpAccessChain %_ptr_StorageBuffer_float %secondMatrix %uint_0 %72
|
||||
%74 = OpLoad %float %73
|
||||
OpReturnValue %74
|
||||
%67 = OpLabel
|
||||
OpReturnValue %float_0
|
||||
OpFunctionEnd
|
||||
%mm_write = OpFunction %void None %75
|
||||
%row_1 = OpFunctionParameter %uint
|
||||
%col_1 = OpFunctionParameter %uint
|
||||
%value = OpFunctionParameter %float
|
||||
%81 = OpLabel
|
||||
%82 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_0
|
||||
%83 = OpLoad %uint %82
|
||||
%84 = OpULessThan %bool %row_1 %83
|
||||
OpSelectionMerge %85 None
|
||||
OpBranchConditional %84 %86 %85
|
||||
%86 = OpLabel
|
||||
%87 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_2
|
||||
%88 = OpLoad %uint %87
|
||||
%89 = OpULessThan %bool %col_1 %88
|
||||
OpBranch %85
|
||||
%85 = OpLabel
|
||||
%90 = OpPhi %bool %84 %81 %89 %86
|
||||
OpSelectionMerge %91 None
|
||||
OpBranchConditional %90 %92 %91
|
||||
%92 = OpLabel
|
||||
%93 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_2
|
||||
%94 = OpLoad %uint %93
|
||||
%95 = OpIMul %uint %row_1 %94
|
||||
%96 = OpIAdd %uint %col_1 %95
|
||||
%97 = OpAccessChain %_ptr_StorageBuffer_float %resultMatrix %uint_0 %96
|
||||
OpStore %97 %value
|
||||
OpBranch %91
|
||||
%91 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%main = OpFunction %void None %98
|
||||
%100 = OpLabel
|
||||
%acc = OpVariable %_ptr_Function__arr_float_uint_16 Function %129
|
||||
%ACached = OpVariable %_ptr_Function_float Function %132
|
||||
%BCached = OpVariable %_ptr_Function__arr_float_RowPerThread Function %136
|
||||
%index = OpVariable %_ptr_Function_uint Function %139
|
||||
%t = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerRow = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerCol = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerRow_0 = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerCol_0 = OpVariable %_ptr_Function_uint Function %139
|
||||
%k = OpVariable %_ptr_Function_uint Function %139
|
||||
%inner = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerRow_1 = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerCol_1 = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerRow_2 = OpVariable %_ptr_Function_uint Function %139
|
||||
%innerCol_2 = OpVariable %_ptr_Function_uint Function %139
|
||||
%101 = OpLoad %uint %tint_symbol_2
|
||||
%102 = OpIEqual %bool %101 %uint_0
|
||||
OpSelectionMerge %103 None
|
||||
OpBranchConditional %102 %104 %103
|
||||
%104 = OpLabel
|
||||
OpStore %mm_Asub %105
|
||||
OpStore %mm_Bsub %105
|
||||
OpBranch %103
|
||||
%103 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
%108 = OpAccessChain %_ptr_Input_uint %tint_symbol %uint_1
|
||||
%109 = OpLoad %uint %108
|
||||
%110 = OpIMul %uint %109 %RowPerThread
|
||||
%111 = OpAccessChain %_ptr_Input_uint %tint_symbol %uint_0
|
||||
%112 = OpLoad %uint %111
|
||||
%113 = OpIMul %uint %112 %RowPerThread
|
||||
%114 = OpAccessChain %_ptr_Input_uint %tint_symbol_1 %uint_1
|
||||
%115 = OpLoad %uint %114
|
||||
%116 = OpIMul %uint %115 %RowPerThread
|
||||
%117 = OpAccessChain %_ptr_Input_uint %tint_symbol_1 %uint_0
|
||||
%118 = OpLoad %uint %117
|
||||
%119 = OpIMul %uint %118 %RowPerThread
|
||||
%120 = OpAccessChain %_ptr_Uniform_uint %uniforms %uint_1
|
||||
%121 = OpLoad %uint %120
|
||||
%122 = OpISub %uint %121 %uint_1
|
||||
%123 = OpUDiv %uint %122 %TileAOuter
|
||||
%124 = OpIAdd %uint %123 %uint_1
|
||||
OpStore %index %uint_0
|
||||
OpBranch %140
|
||||
%140 = OpLabel
|
||||
OpLoopMerge %141 %142 None
|
||||
OpBranch %143
|
||||
%143 = OpLabel
|
||||
%145 = OpLoad %uint %index
|
||||
%146 = OpIMul %uint %RowPerThread %RowPerThread
|
||||
%147 = OpULessThan %bool %145 %146
|
||||
%144 = OpLogicalNot %bool %147
|
||||
OpSelectionMerge %148 None
|
||||
OpBranchConditional %144 %149 %148
|
||||
%149 = OpLabel
|
||||
OpBranch %141
|
||||
%148 = OpLabel
|
||||
%150 = OpLoad %uint %index
|
||||
%151 = OpAccessChain %_ptr_Function_float %acc %150
|
||||
OpStore %151 %float_0
|
||||
OpBranch %142
|
||||
%142 = OpLabel
|
||||
%152 = OpLoad %uint %index
|
||||
%153 = OpIAdd %uint %152 %uint_1
|
||||
OpStore %index %153
|
||||
OpBranch %140
|
||||
%141 = OpLabel
|
||||
%154 = OpUDiv %uint %TileAOuter %uint_16
|
||||
%155 = OpAccessChain %_ptr_Input_uint %tint_symbol %uint_0
|
||||
%156 = OpLoad %uint %155
|
||||
%157 = OpIMul %uint %156 %154
|
||||
%158 = OpUDiv %uint %TileAOuter %uint_16
|
||||
%159 = OpAccessChain %_ptr_Input_uint %tint_symbol %uint_1
|
||||
%160 = OpLoad %uint %159
|
||||
%161 = OpIMul %uint %160 %158
|
||||
OpStore %t %uint_0
|
||||
OpBranch %163
|
||||
%163 = OpLabel
|
||||
OpLoopMerge %164 %165 None
|
||||
OpBranch %166
|
||||
%166 = OpLabel
|
||||
%168 = OpLoad %uint %t
|
||||
%169 = OpULessThan %bool %168 %124
|
||||
%167 = OpLogicalNot %bool %169
|
||||
OpSelectionMerge %170 None
|
||||
OpBranchConditional %167 %171 %170
|
||||
%171 = OpLabel
|
||||
OpBranch %164
|
||||
%170 = OpLabel
|
||||
OpStore %innerRow %uint_0
|
||||
OpBranch %173
|
||||
%173 = OpLabel
|
||||
OpLoopMerge %174 %175 None
|
||||
OpBranch %176
|
||||
%176 = OpLabel
|
||||
%178 = OpLoad %uint %innerRow
|
||||
%179 = OpULessThan %bool %178 %RowPerThread
|
||||
%177 = OpLogicalNot %bool %179
|
||||
OpSelectionMerge %180 None
|
||||
OpBranchConditional %177 %181 %180
|
||||
%181 = OpLabel
|
||||
OpBranch %174
|
||||
%180 = OpLabel
|
||||
OpStore %innerCol %uint_0
|
||||
OpBranch %183
|
||||
%183 = OpLabel
|
||||
OpLoopMerge %184 %185 None
|
||||
OpBranch %186
|
||||
%186 = OpLabel
|
||||
%188 = OpLoad %uint %innerCol
|
||||
%189 = OpULessThan %bool %188 %154
|
||||
%187 = OpLogicalNot %bool %189
|
||||
OpSelectionMerge %190 None
|
||||
OpBranchConditional %187 %191 %190
|
||||
%191 = OpLabel
|
||||
OpBranch %184
|
||||
%190 = OpLabel
|
||||
%192 = OpLoad %uint %innerRow
|
||||
%193 = OpIAdd %uint %110 %192
|
||||
%194 = OpLoad %uint %innerCol
|
||||
%195 = OpIAdd %uint %157 %194
|
||||
%197 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %193 %195
|
||||
%199 = OpLoad %uint %innerRow
|
||||
%200 = OpIAdd %uint %116 %199
|
||||
%201 = OpLoad %uint %t
|
||||
%202 = OpIMul %uint %201 %TileAOuter
|
||||
%203 = OpIAdd %uint %202 %195
|
||||
%198 = OpFunctionCall %float %mm_readA %200 %203
|
||||
OpStore %197 %198
|
||||
OpBranch %185
|
||||
%185 = OpLabel
|
||||
%204 = OpLoad %uint %innerCol
|
||||
%205 = OpIAdd %uint %204 %uint_1
|
||||
OpStore %innerCol %205
|
||||
OpBranch %183
|
||||
%184 = OpLabel
|
||||
OpBranch %175
|
||||
%175 = OpLabel
|
||||
%206 = OpLoad %uint %innerRow
|
||||
%207 = OpIAdd %uint %206 %uint_1
|
||||
OpStore %innerRow %207
|
||||
OpBranch %173
|
||||
%174 = OpLabel
|
||||
OpStore %innerRow_0 %uint_0
|
||||
OpBranch %209
|
||||
%209 = OpLabel
|
||||
OpLoopMerge %210 %211 None
|
||||
OpBranch %212
|
||||
%212 = OpLabel
|
||||
%214 = OpLoad %uint %innerRow_0
|
||||
%215 = OpULessThan %bool %214 %158
|
||||
%213 = OpLogicalNot %bool %215
|
||||
OpSelectionMerge %216 None
|
||||
OpBranchConditional %213 %217 %216
|
||||
%217 = OpLabel
|
||||
OpBranch %210
|
||||
%216 = OpLabel
|
||||
OpStore %innerCol_0 %uint_0
|
||||
OpBranch %219
|
||||
%219 = OpLabel
|
||||
OpLoopMerge %220 %221 None
|
||||
OpBranch %222
|
||||
%222 = OpLabel
|
||||
%224 = OpLoad %uint %innerCol_0
|
||||
%225 = OpULessThan %bool %224 %RowPerThread
|
||||
%223 = OpLogicalNot %bool %225
|
||||
OpSelectionMerge %226 None
|
||||
OpBranchConditional %223 %227 %226
|
||||
%227 = OpLabel
|
||||
OpBranch %220
|
||||
%226 = OpLabel
|
||||
%228 = OpLoad %uint %innerRow_0
|
||||
%229 = OpIAdd %uint %161 %228
|
||||
%230 = OpLoad %uint %innerCol_0
|
||||
%231 = OpIAdd %uint %113 %230
|
||||
%232 = OpLoad %uint %innerCol_0
|
||||
%233 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %232 %231
|
||||
%235 = OpLoad %uint %t
|
||||
%236 = OpIMul %uint %235 %TileAOuter
|
||||
%237 = OpIAdd %uint %236 %229
|
||||
%238 = OpLoad %uint %innerCol_0
|
||||
%239 = OpIAdd %uint %119 %238
|
||||
%234 = OpFunctionCall %float %mm_readB %237 %239
|
||||
OpStore %233 %234
|
||||
OpBranch %221
|
||||
%221 = OpLabel
|
||||
%240 = OpLoad %uint %innerCol_0
|
||||
%241 = OpIAdd %uint %240 %uint_1
|
||||
OpStore %innerCol_0 %241
|
||||
OpBranch %219
|
||||
%220 = OpLabel
|
||||
OpBranch %211
|
||||
%211 = OpLabel
|
||||
%242 = OpLoad %uint %innerRow_0
|
||||
%243 = OpIAdd %uint %242 %uint_1
|
||||
OpStore %innerRow_0 %243
|
||||
OpBranch %209
|
||||
%210 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpStore %k %uint_0
|
||||
OpBranch %246
|
||||
%246 = OpLabel
|
||||
OpLoopMerge %247 %248 None
|
||||
OpBranch %249
|
||||
%249 = OpLabel
|
||||
%251 = OpLoad %uint %k
|
||||
%252 = OpULessThan %bool %251 %TileAOuter
|
||||
%250 = OpLogicalNot %bool %252
|
||||
OpSelectionMerge %253 None
|
||||
OpBranchConditional %250 %254 %253
|
||||
%254 = OpLabel
|
||||
OpBranch %247
|
||||
%253 = OpLabel
|
||||
OpStore %inner %uint_0
|
||||
OpBranch %256
|
||||
%256 = OpLabel
|
||||
OpLoopMerge %257 %258 None
|
||||
OpBranch %259
|
||||
%259 = OpLabel
|
||||
%261 = OpLoad %uint %inner
|
||||
%262 = OpULessThan %bool %261 %RowPerThread
|
||||
%260 = OpLogicalNot %bool %262
|
||||
OpSelectionMerge %263 None
|
||||
OpBranchConditional %260 %264 %263
|
||||
%264 = OpLabel
|
||||
OpBranch %257
|
||||
%263 = OpLabel
|
||||
%265 = OpLoad %uint %inner
|
||||
%266 = OpAccessChain %_ptr_Function_float %BCached %265
|
||||
%267 = OpLoad %uint %k
|
||||
%268 = OpLoad %uint %inner
|
||||
%269 = OpIAdd %uint %113 %268
|
||||
%270 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %267 %269
|
||||
%271 = OpLoad %float %270
|
||||
OpStore %266 %271
|
||||
OpBranch %258
|
||||
%258 = OpLabel
|
||||
%272 = OpLoad %uint %inner
|
||||
%273 = OpIAdd %uint %272 %uint_1
|
||||
OpStore %inner %273
|
||||
OpBranch %256
|
||||
%257 = OpLabel
|
||||
OpStore %innerRow_1 %uint_0
|
||||
OpBranch %275
|
||||
%275 = OpLabel
|
||||
OpLoopMerge %276 %277 None
|
||||
OpBranch %278
|
||||
%278 = OpLabel
|
||||
%280 = OpLoad %uint %innerRow_1
|
||||
%281 = OpULessThan %bool %280 %RowPerThread
|
||||
%279 = OpLogicalNot %bool %281
|
||||
OpSelectionMerge %282 None
|
||||
OpBranchConditional %279 %283 %282
|
||||
%283 = OpLabel
|
||||
OpBranch %276
|
||||
%282 = OpLabel
|
||||
%284 = OpLoad %uint %innerRow_1
|
||||
%285 = OpIAdd %uint %110 %284
|
||||
%286 = OpLoad %uint %k
|
||||
%287 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %285 %286
|
||||
%288 = OpLoad %float %287
|
||||
OpStore %ACached %288
|
||||
OpStore %innerCol_1 %uint_0
|
||||
OpBranch %290
|
||||
%290 = OpLabel
|
||||
OpLoopMerge %291 %292 None
|
||||
OpBranch %293
|
||||
%293 = OpLabel
|
||||
%295 = OpLoad %uint %innerCol_1
|
||||
%296 = OpULessThan %bool %295 %RowPerThread
|
||||
%294 = OpLogicalNot %bool %296
|
||||
OpSelectionMerge %297 None
|
||||
OpBranchConditional %294 %298 %297
|
||||
%298 = OpLabel
|
||||
OpBranch %291
|
||||
%297 = OpLabel
|
||||
%299 = OpLoad %uint %innerRow_1
|
||||
%300 = OpIMul %uint %299 %RowPerThread
|
||||
%301 = OpLoad %uint %innerCol_1
|
||||
%302 = OpIAdd %uint %300 %301
|
||||
%303 = OpAccessChain %_ptr_Function_float %acc %302
|
||||
%304 = OpAccessChain %_ptr_Function_float %acc %302
|
||||
%305 = OpLoad %float %304
|
||||
%306 = OpLoad %float %ACached
|
||||
%307 = OpLoad %uint %innerCol_1
|
||||
%308 = OpAccessChain %_ptr_Function_float %BCached %307
|
||||
%309 = OpLoad %float %308
|
||||
%310 = OpFMul %float %306 %309
|
||||
%311 = OpFAdd %float %305 %310
|
||||
OpStore %303 %311
|
||||
OpBranch %292
|
||||
%292 = OpLabel
|
||||
%312 = OpLoad %uint %innerCol_1
|
||||
%313 = OpIAdd %uint %312 %uint_1
|
||||
OpStore %innerCol_1 %313
|
||||
OpBranch %290
|
||||
%291 = OpLabel
|
||||
OpBranch %277
|
||||
%277 = OpLabel
|
||||
%314 = OpLoad %uint %innerRow_1
|
||||
%315 = OpIAdd %uint %314 %uint_1
|
||||
OpStore %innerRow_1 %315
|
||||
OpBranch %275
|
||||
%276 = OpLabel
|
||||
OpBranch %248
|
||||
%248 = OpLabel
|
||||
%316 = OpLoad %uint %k
|
||||
%317 = OpIAdd %uint %316 %uint_1
|
||||
OpStore %k %317
|
||||
OpBranch %246
|
||||
%247 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpBranch %165
|
||||
%165 = OpLabel
|
||||
%319 = OpLoad %uint %t
|
||||
%320 = OpIAdd %uint %319 %uint_1
|
||||
OpStore %t %320
|
||||
OpBranch %163
|
||||
%164 = OpLabel
|
||||
OpStore %innerRow_2 %uint_0
|
||||
OpBranch %322
|
||||
%322 = OpLabel
|
||||
OpLoopMerge %323 %324 None
|
||||
OpBranch %325
|
||||
%325 = OpLabel
|
||||
%327 = OpLoad %uint %innerRow_2
|
||||
%328 = OpULessThan %bool %327 %RowPerThread
|
||||
%326 = OpLogicalNot %bool %328
|
||||
OpSelectionMerge %329 None
|
||||
OpBranchConditional %326 %330 %329
|
||||
%330 = OpLabel
|
||||
OpBranch %323
|
||||
%329 = OpLabel
|
||||
OpStore %innerCol_2 %uint_0
|
||||
OpBranch %332
|
||||
%332 = OpLabel
|
||||
OpLoopMerge %333 %334 None
|
||||
OpBranch %335
|
||||
%335 = OpLabel
|
||||
%337 = OpLoad %uint %innerCol_2
|
||||
%338 = OpULessThan %bool %337 %RowPerThread
|
||||
%336 = OpLogicalNot %bool %338
|
||||
OpSelectionMerge %339 None
|
||||
OpBranchConditional %336 %340 %339
|
||||
%340 = OpLabel
|
||||
OpBranch %333
|
||||
%339 = OpLabel
|
||||
%341 = OpLoad %uint %innerRow_2
|
||||
%342 = OpIMul %uint %341 %RowPerThread
|
||||
%343 = OpLoad %uint %innerCol_2
|
||||
%344 = OpIAdd %uint %342 %343
|
||||
%346 = OpLoad %uint %innerRow_2
|
||||
%347 = OpIAdd %uint %116 %346
|
||||
%348 = OpLoad %uint %innerCol_2
|
||||
%349 = OpIAdd %uint %119 %348
|
||||
%350 = OpAccessChain %_ptr_Function_float %acc %344
|
||||
%351 = OpLoad %float %350
|
||||
%345 = OpFunctionCall %void %mm_write %347 %349 %351
|
||||
OpBranch %334
|
||||
%334 = OpLabel
|
||||
%352 = OpLoad %uint %innerCol_2
|
||||
%353 = OpIAdd %uint %352 %uint_1
|
||||
OpStore %innerCol_2 %353
|
||||
OpBranch %332
|
||||
%333 = OpLabel
|
||||
OpBranch %324
|
||||
%324 = OpLabel
|
||||
%354 = OpLoad %uint %innerRow_2
|
||||
%355 = OpIAdd %uint %354 %uint_1
|
||||
OpStore %innerRow_2 %355
|
||||
OpBranch %322
|
||||
%323 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,231 @@
|
|||
[[block]]
|
||||
struct Uniforms {
|
||||
dimAOuter : u32;
|
||||
dimInner : u32;
|
||||
dimBOuter : u32;
|
||||
};
|
||||
|
||||
[[block]]
|
||||
struct Matrix {
|
||||
numbers : array<f32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix;
|
||||
|
||||
[[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix;
|
||||
|
||||
[[group(0), binding(2)]] var<storage, write> resultMatrix : Matrix;
|
||||
|
||||
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn mm_readA(row : u32, col : u32) -> f32 {
|
||||
if (((row < uniforms.dimAOuter) && (col < uniforms.dimInner))) {
|
||||
let result : f32 = firstMatrix.numbers[((row * uniforms.dimInner) + col)];
|
||||
return result;
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
fn mm_readB(row : u32, col : u32) -> f32 {
|
||||
if (((row < uniforms.dimInner) && (col < uniforms.dimBOuter))) {
|
||||
let result : f32 = secondMatrix.numbers[((row * uniforms.dimBOuter) + col)];
|
||||
return result;
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
fn mm_write(row : u32, col : u32, value : f32) {
|
||||
if (((row < uniforms.dimAOuter) && (col < uniforms.dimBOuter))) {
|
||||
let index : u32 = (col + (row * uniforms.dimBOuter));
|
||||
resultMatrix.numbers[index] = value;
|
||||
}
|
||||
}
|
||||
|
||||
let RowPerThread : u32 = 4u;
|
||||
|
||||
let ColPerThread : u32 = 4u;
|
||||
|
||||
let TileAOuter : u32 = 64u;
|
||||
|
||||
let TileBOuter : u32 = 64u;
|
||||
|
||||
let TileInner : u32 = 64u;
|
||||
|
||||
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
|
||||
|
||||
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;
|
||||
|
||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>, [[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
||||
let tileRow : u32 = (local_id.y * RowPerThread);
|
||||
let tileCol : u32 = (local_id.x * ColPerThread);
|
||||
let globalRow : u32 = (global_id.y * RowPerThread);
|
||||
let globalCol : u32 = (global_id.x * ColPerThread);
|
||||
let numTiles : u32 = (((uniforms.dimInner - 1u) / TileInner) + 1u);
|
||||
var acc : array<f32, 16>;
|
||||
var ACached : f32;
|
||||
var BCached : array<f32, 4>;
|
||||
{
|
||||
var index : u32 = 0u;
|
||||
loop {
|
||||
if (!((index < (RowPerThread * ColPerThread)))) {
|
||||
break;
|
||||
}
|
||||
acc[index] = 0.0;
|
||||
|
||||
continuing {
|
||||
index = (index + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
let ColPerThreadA : u32 = (TileInner / 16u);
|
||||
let tileColA : u32 = (local_id.x * ColPerThreadA);
|
||||
let RowPerThreadB : u32 = (TileInner / 16u);
|
||||
let tileRowB : u32 = (local_id.y * RowPerThreadB);
|
||||
{
|
||||
var t : u32 = 0u;
|
||||
loop {
|
||||
if (!((t < numTiles))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var innerRow : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerRow < RowPerThread))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var innerCol : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerCol < ColPerThreadA))) {
|
||||
break;
|
||||
}
|
||||
let inputRow : u32 = (tileRow + innerRow);
|
||||
let inputCol : u32 = (tileColA + innerCol);
|
||||
mm_Asub[inputRow][inputCol] = mm_readA((globalRow + innerRow), ((t * TileInner) + inputCol));
|
||||
|
||||
continuing {
|
||||
innerCol = (innerCol + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
innerRow = (innerRow + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var innerRow : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerRow < RowPerThreadB))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var innerCol : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerCol < ColPerThread))) {
|
||||
break;
|
||||
}
|
||||
let inputRow : u32 = (tileRowB + innerRow);
|
||||
let inputCol : u32 = (tileCol + innerCol);
|
||||
mm_Bsub[innerCol][inputCol] = mm_readB(((t * TileInner) + inputRow), (globalCol + innerCol));
|
||||
|
||||
continuing {
|
||||
innerCol = (innerCol + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
innerRow = (innerRow + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
{
|
||||
var k : u32 = 0u;
|
||||
loop {
|
||||
if (!((k < TileInner))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var inner : u32 = 0u;
|
||||
loop {
|
||||
if (!((inner < ColPerThread))) {
|
||||
break;
|
||||
}
|
||||
BCached[inner] = mm_Bsub[k][(tileCol + inner)];
|
||||
|
||||
continuing {
|
||||
inner = (inner + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var innerRow : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerRow < RowPerThread))) {
|
||||
break;
|
||||
}
|
||||
ACached = mm_Asub[(tileRow + innerRow)][k];
|
||||
{
|
||||
var innerCol : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerCol < ColPerThread))) {
|
||||
break;
|
||||
}
|
||||
let index : u32 = ((innerRow * ColPerThread) + innerCol);
|
||||
acc[index] = (acc[index] + (ACached * BCached[innerCol]));
|
||||
|
||||
continuing {
|
||||
innerCol = (innerCol + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
innerRow = (innerRow + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
k = (k + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
continuing {
|
||||
t = (t + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
var innerRow : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerRow < RowPerThread))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var innerCol : u32 = 0u;
|
||||
loop {
|
||||
if (!((innerCol < ColPerThread))) {
|
||||
break;
|
||||
}
|
||||
let index : u32 = ((innerRow * ColPerThread) + innerCol);
|
||||
mm_write((globalRow + innerRow), (globalCol + innerCol), acc[index]);
|
||||
|
||||
continuing {
|
||||
innerCol = (innerCol + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
innerRow = (innerRow + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue