Add matmul shader support for running multiple calculations in parallel

This commit is contained in:
0cc4m 2023-12-30 12:36:24 +01:00
parent e9e2be33fd
commit 7b36cea8a3
3 changed files with 32745 additions and 30864 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -259,69 +259,86 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
// Push-constant block for the matmul shader; members are read through `p`.
// NOTE(review): this text is a rendered diff without +/- markers. The `int`
// members look like the pre-change declarations and the `uint` members like
// their replacements (plus new batch fields) — confirm against the real file.
layout (push_constant) uniform parameter
{
int M;
int N;
int K;
int stride_a;
int stride_b;
int stride_d;
int k_split;
int d_offset;
// M and N bound the output: a result is stored only when its row index is < M
// and its column index is < N. K caps the end of the k range (end_k).
uint M;
uint N;
uint K;
// Multiplied by a row/column index when computing positions in data_a,
// data_b and data_d respectively.
uint stride_a;
uint stride_b;
uint stride_d;
// Work-group slice ik covers k in [ik * k_split, min(K, (ik + 1) * k_split)).
uint k_split;
// Constant offset added to every data_d index.
uint d_offset;
// Batch indexing: gl_GlobalInvocationID.z is split into i13 = z / ne12 and
// i12 = z % ne12; i03 = i13 / broadcast3 and i02 = i12 / broadcast2; the
// A-side batch index is i03 * ne02 + i02.
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
// Per-batch offsets: batch_stride_a is scaled by the A-side batch index,
// batch_stride_b and batch_stride_d by gl_GlobalInvocationID.z.
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
} p;
// Specialization constants (constant_id 1-8) with their default values.
// NOTE(review): rendered diff without +/- markers — the `int` set looks like
// the old declarations and the `uint` set like the new ones; confirm.
layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;
// BM / BN: row counts of the shared buf_a / buf_b tiles; also the work-group
//          step in the output (dr starts at ir * BM, dc at ic * BN).
// BK:      step of the loop over k (block += BK) and width filled per tile row.
// WM / WN: step per warp_r / warp_c inside a work-group tile.
// WMITER:  number of sub-tiles per warp along the WM direction (WSUBM = WM / WMITER).
// TM / TN: per-invocation element counts; the accumulator array `sums` has
//          WMITER * TM * WNITER * TN entries.
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
// Shared staging tiles for A (BM rows) and B (BN rows). Rows are addressed with
// a pitch of BK+1 elements although only columns 0..BK-1 are written and read
// (presumably padding to reduce shared-memory bank conflicts — TODO confirm).
shared FLOAT_TYPE buf_a[BM * (BK+1)];
shared FLOAT_TYPE buf_b[BN * (BK+1)];
void main() {
const int blocks_x = (p.M + BM - 1) / BM;
const int ir = int(gl_WorkGroupID.x) % blocks_x;
const int ik = int(gl_WorkGroupID.x) / blocks_x;
const int ic = int(gl_WorkGroupID.y);
const uint i13 = gl_GlobalInvocationID.z / p.ne12;
const uint i12 = gl_GlobalInvocationID.z % p.ne12;
const int warp_i = int(gl_LocalInvocationID.x / WARP);
const int warp_r = warp_i % (BM / WM);
const int warp_c = warp_i / (BM / WM);
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const int WSUBM = WM / WMITER;
const int WSUBN = WN / WNITER;
const uint batch_idx_a = i03 * p.ne02 + i02;
const int tiw = int(gl_LocalInvocationID.x % WARP);
const int tiwr = tiw % (WSUBM / TM);
const int tiwc = tiw / (WSUBM / TM);
const uint blocks_x = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_x;
const uint ik = gl_WorkGroupID.x / blocks_x;
const uint ic = gl_WorkGroupID.y;
const int loadr = int(gl_LocalInvocationID.x % (BK / LOAD_VEC));
const int loadc = int(gl_LocalInvocationID.x / (BK / LOAD_VEC));
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const int loadstride = int(gl_WorkGroupSize.x * LOAD_VEC) / BK;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
const int start_k = ik * p.k_split;
const int end_k = min(p.K, (ik + 1) * p.k_split);
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC;
int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC;
const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN];
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f;
}
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride) {
#if LOAD_VEC == 8
const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
@ -331,7 +348,7 @@ void main() {
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC == 4
const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
@ -344,9 +361,9 @@ void main() {
}
#endif
}
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
[[unroll]] for (uint l = 0; l < BN; l += loadstride) {
#if LOAD_VEC == 8
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
@ -356,7 +373,7 @@ void main() {
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC == 4
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
@ -375,23 +392,23 @@ void main() {
pos_a += BK / LOAD_VEC;
pos_b += BK / LOAD_VEC;
for (int i = 0; i < BK; i++) {
for (uint i = 0; i < BK; i++) {
// Load from shared into cache
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int j = 0; j < TM; j++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
}
}
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (int j = 0; j < TN; j++) {
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
}
}
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) {
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
}
}
@ -402,20 +419,20 @@ void main() {
barrier();
}
const int dr = ir * BM + warp_r * WM;
const int dc = ic * BN + warp_c * WN;
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
const int k_split_offset = ik * p.M * p.N;
const uint k_split_offset = ik * p.M * p.N;
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
[[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) {
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[p.d_offset + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[p.d_offset + gl_GlobalInvocationID.z * p.batch_stride_d + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
}
}