From 0c9cca00bd984afb35cc303f7fb5ff49abbd57be Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 25 Jun 2023 09:54:40 +0200 Subject: [PATCH] Write coalescing --- ggml-vulkan-matmul.comp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml-vulkan-matmul.comp b/ggml-vulkan-matmul.comp index 05e24f9ae..8fc894c37 100644 --- a/ggml-vulkan-matmul.comp +++ b/ggml-vulkan-matmul.comp @@ -31,8 +31,10 @@ void main() { const int ir = int(gl_WorkGroupID.x); const int ic = int(gl_WorkGroupID.y); - const int lr = int(gl_LocalInvocationID.x % (BM/TM)); - const int lc = int(gl_LocalInvocationID.x / (BM/TM)); + const int rstride = BM / TM; + + const int lr = int(gl_LocalInvocationID.x % rstride); + const int lc = int(gl_LocalInvocationID.x / rstride); const int loadr = int(gl_LocalInvocationID.x % BK); const int loadc = int(gl_LocalInvocationID.x / BK); @@ -50,7 +52,7 @@ void main() { sums[i] = 0.0f; } - for (int block = 0; block < p.K; block += BK) { + [[unroll]] for (int block = 0; block < p.K; block += BK) { [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { const int lr = l % BK; const int lc = l / BK; @@ -69,8 +71,8 @@ void main() { [[unroll]] for (int i = 0; i < BK; i++) { // Load from shared into cache - [[unroll]] for (int j = 0; j < TM; j++) { - cache_a[j] = buf_a[(lr * TM + j) * (BK+1) + i]; + [[unroll]] for (int j = 0; j < BM; j++) { + cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i]; } [[unroll]] for (int j = 0; j < TN; j++) { cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i]; @@ -86,12 +88,12 @@ void main() { barrier(); } - const int dr = ir * BM + lr * TM; + const int dr = ir * BM + lr; const int dc = ic * BN + lc * TN; [[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cr = 0; cr < TM; cr++) { - data_d[(dc + cc) * p.stride_d + dr + cr] = sums[cc * TM + cr]; + data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr]; } } }