Add matmul shader support for running multiple calculations in parallel
This commit is contained in:
parent
e9e2be33fd
commit
7b36cea8a3
3 changed files with 32745 additions and 30864 deletions
62051
ggml-vulkan-shaders.hpp
62051
ggml-vulkan-shaders.hpp
File diff suppressed because it is too large
Load diff
1343
ggml-vulkan.cpp
1343
ggml-vulkan.cpp
File diff suppressed because it is too large
Load diff
|
@ -259,69 +259,86 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
{
|
{
|
||||||
int M;
|
uint M;
|
||||||
int N;
|
uint N;
|
||||||
int K;
|
uint K;
|
||||||
int stride_a;
|
uint stride_a;
|
||||||
int stride_b;
|
uint stride_b;
|
||||||
int stride_d;
|
uint stride_d;
|
||||||
int k_split;
|
uint k_split;
|
||||||
int d_offset;
|
uint d_offset;
|
||||||
|
|
||||||
|
uint ne02;
|
||||||
|
uint ne12;
|
||||||
|
uint broadcast2;
|
||||||
|
uint broadcast3;
|
||||||
|
|
||||||
|
uint batch_stride_a;
|
||||||
|
uint batch_stride_b;
|
||||||
|
uint batch_stride_d;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
layout (constant_id = 1) const int BM = 64;
|
layout (constant_id = 1) const uint BM = 64;
|
||||||
layout (constant_id = 2) const int BN = 64;
|
layout (constant_id = 2) const uint BN = 64;
|
||||||
layout (constant_id = 3) const int BK = 16;
|
layout (constant_id = 3) const uint BK = 16;
|
||||||
layout (constant_id = 4) const int WM = 32;
|
layout (constant_id = 4) const uint WM = 32;
|
||||||
layout (constant_id = 5) const int WN = 32;
|
layout (constant_id = 5) const uint WN = 32;
|
||||||
layout (constant_id = 6) const int WMITER = 2;
|
layout (constant_id = 6) const uint WMITER = 2;
|
||||||
layout (constant_id = 7) const int TM = 4;
|
layout (constant_id = 7) const uint TM = 4;
|
||||||
layout (constant_id = 8) const int TN = 2;
|
layout (constant_id = 8) const uint TN = 2;
|
||||||
|
|
||||||
shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
||||||
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const int blocks_x = (p.M + BM - 1) / BM;
|
const uint i13 = gl_GlobalInvocationID.z / p.ne12;
|
||||||
const int ir = int(gl_WorkGroupID.x) % blocks_x;
|
const uint i12 = gl_GlobalInvocationID.z % p.ne12;
|
||||||
const int ik = int(gl_WorkGroupID.x) / blocks_x;
|
|
||||||
const int ic = int(gl_WorkGroupID.y);
|
|
||||||
|
|
||||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
const uint i03 = i13 / p.broadcast3;
|
||||||
const int warp_r = warp_i % (BM / WM);
|
const uint i02 = i12 / p.broadcast2;
|
||||||
const int warp_c = warp_i / (BM / WM);
|
|
||||||
|
|
||||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
const uint batch_idx_a = i03 * p.ne02 + i02;
|
||||||
const int WSUBM = WM / WMITER;
|
|
||||||
const int WSUBN = WN / WNITER;
|
|
||||||
|
|
||||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
const uint blocks_x = (p.M + BM - 1) / BM;
|
||||||
const int tiwr = tiw % (WSUBM / TM);
|
const uint ir = gl_WorkGroupID.x % blocks_x;
|
||||||
const int tiwc = tiw / (WSUBM / TM);
|
const uint ik = gl_WorkGroupID.x / blocks_x;
|
||||||
|
const uint ic = gl_WorkGroupID.y;
|
||||||
|
|
||||||
const int loadr = int(gl_LocalInvocationID.x % (BK / LOAD_VEC));
|
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||||
const int loadc = int(gl_LocalInvocationID.x / (BK / LOAD_VEC));
|
const uint warp_r = warp_i % (BM / WM);
|
||||||
|
const uint warp_c = warp_i / (BM / WM);
|
||||||
|
|
||||||
const int loadstride = int(gl_WorkGroupSize.x * LOAD_VEC) / BK;
|
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||||
|
const uint WSUBM = WM / WMITER;
|
||||||
|
const uint WSUBN = WN / WNITER;
|
||||||
|
|
||||||
const int start_k = ik * p.k_split;
|
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||||
const int end_k = min(p.K, (ik + 1) * p.k_split);
|
const uint tiwr = tiw % (WSUBM / TM);
|
||||||
|
const uint tiwc = tiw / (WSUBM / TM);
|
||||||
|
|
||||||
int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC;
|
const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
|
||||||
int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC;
|
const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
|
||||||
|
|
||||||
|
const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
|
||||||
|
|
||||||
|
const uint start_k = ik * p.k_split;
|
||||||
|
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
||||||
|
|
||||||
|
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
|
||||||
|
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
|
||||||
|
|
||||||
float sums[WMITER * TM * WNITER * TN];
|
float sums[WMITER * TM * WNITER * TN];
|
||||||
FLOAT_TYPE cache_a[WMITER * TM];
|
FLOAT_TYPE cache_a[WMITER * TM];
|
||||||
FLOAT_TYPE cache_b[WNITER * TN];
|
FLOAT_TYPE cache_b[WNITER * TN];
|
||||||
|
|
||||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||||
sums[i] = 0.0f;
|
sums[i] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
|
||||||
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
|
[[unroll]] for (uint l = 0; l < BM; l += loadstride) {
|
||||||
#if LOAD_VEC == 8
|
#if LOAD_VEC == 8
|
||||||
const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
||||||
|
@ -331,7 +348,7 @@ void main() {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
||||||
#elif LOAD_VEC == 4
|
#elif LOAD_VEC == 4
|
||||||
const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
|
||||||
|
@ -344,9 +361,9 @@ void main() {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
|
[[unroll]] for (uint l = 0; l < BN; l += loadstride) {
|
||||||
#if LOAD_VEC == 8
|
#if LOAD_VEC == 8
|
||||||
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
||||||
|
@ -356,7 +373,7 @@ void main() {
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||||
#elif LOAD_VEC == 4
|
#elif LOAD_VEC == 4
|
||||||
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
|
||||||
|
@ -375,23 +392,23 @@ void main() {
|
||||||
pos_a += BK / LOAD_VEC;
|
pos_a += BK / LOAD_VEC;
|
||||||
pos_b += BK / LOAD_VEC;
|
pos_b += BK / LOAD_VEC;
|
||||||
|
|
||||||
for (int i = 0; i < BK; i++) {
|
for (uint i = 0; i < BK; i++) {
|
||||||
// Load from shared into cache
|
// Load from shared into cache
|
||||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -402,20 +419,20 @@ void main() {
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int dr = ir * BM + warp_r * WM;
|
const uint dr = ir * BM + warp_r * WM;
|
||||||
const int dc = ic * BN + warp_c * WN;
|
const uint dc = ic * BN + warp_c * WN;
|
||||||
|
|
||||||
const int k_split_offset = ik * p.M * p.N;
|
const uint k_split_offset = ik * p.M * p.N;
|
||||||
|
|
||||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
|
||||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||||
data_d[p.d_offset + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
data_d[p.d_offset + gl_GlobalInvocationID.z * p.batch_stride_d + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue