From 7c6860b483fb04a4b922d3a54ac21097dcba0db0 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 24 Jun 2023 18:40:11 +0200 Subject: [PATCH] 2D Blocktiling --- /* Review note: this patch changes the matmul compute shader and its host dispatch together. Shader: BM and BN go from 64 to 128, TN (8) is added, and local_size_x becomes (BM * BN) / (TM * TN). Each invocation now keeps TM * TN running sums instead of TM, and for every step i < BK it first copies TM values from buf_a and TN values from buf_b into cache_a / cache_b before accumulating. buf_a and buf_b are filled by loops that advance by gl_WorkGroupSize.x per iteration. The result is written as a TM x TN group of data_d elements starting at (dr, dc). Host: the dispatch divisors change from 64 to 128, matching the new BM / BN. NOTE(review): the lr / lc declared inside the two load loops shadow the outer per-invocation lr / lc; distinct names would be easier to read. NOTE(review): no bounds checks on the data_a / data_b reads or data_d writes are visible in these hunks while the host uses CEIL_DIV; presumably dimensions are expected to be multiples of the tile sizes - confirm against the caller. */ ggml-vulkan-matmul.comp | 65 ++++++++++++++++++++++++++++------------- ggml-vulkan.cpp | 2 +- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/ggml-vulkan-matmul.comp b/ggml-vulkan-matmul.comp index 3166bb659..05e24f9ae 100644 --- a/ggml-vulkan-matmul.comp +++ b/ggml-vulkan-matmul.comp @@ -1,13 +1,14 @@ #version 450 -#define BM 64 -#define BN 64 +#define BM 128 +#define BN 128 #define BK 8 #define TM 8 +#define TN 8 #extension GL_EXT_control_flow_attributes : enable -layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A { float data_a[]; }; layout (binding = 1) readonly buffer B { float data_b[]; }; @@ -30,33 +31,55 @@ void main() { const int ir = int(gl_WorkGroupID.x); const int ic = int(gl_WorkGroupID.y); - const int lr = int(gl_LocalInvocationID.x % BK); - const int lc = int(gl_LocalInvocationID.x / BK); + const int lr = int(gl_LocalInvocationID.x % (BM/TM)); + const int lc = int(gl_LocalInvocationID.x / (BM/TM)); + + const int loadr = int(gl_LocalInvocationID.x % BK); + const int loadc = int(gl_LocalInvocationID.x / BK); + + const int loadstride = int(gl_WorkGroupSize.x); int pos_a = ir * BM * p.stride_a; int pos_b = ic * BN * p.stride_b; - float sums[TM]; - float btmp; + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; - [[unroll]] for (int k = 0; k < TM; k++) { - sums[k] = 0.0f; + [[unroll]] for (int i = 0; i < TM*TN; i++) { + sums[i] = 0.0f; } - for (int i = 0; i < p.K; i += BK) { - // A is transposed - buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr]; - buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr]; + for (int block = 0; block < p.K; block += BK) { + [[unroll]] for (int l = 0; l < BM * BK; l += 
loadstride) { + const int lr = l % BK; + const int lc = l / BK; + buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr]; + } + [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) { + const int lr = l % BK; + const int lc = l / BK; + buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]; + } barrier(); pos_a += BK; pos_b += BK; - [[unroll]] for (int j = 0; j < BK; j++) { - btmp = buf_b[lc * (BK+1) + j]; - [[unroll]] for (int k = 0; k < TM; k++) { - sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp; + [[unroll]] for (int i = 0; i < BK; i++) { + // Load from shared into cache + [[unroll]] for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(lr * TM + j) * (BK+1) + i]; + } + [[unroll]] for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i]; + } + + [[unroll]] for (int cc = 0; cc < TN; cc++) { + [[unroll]] for (int cr = 0; cr < TM; cr++) { + sums[cc * TM + cr] += cache_a[cr] * cache_b[cc]; + } } } @@ -64,9 +87,11 @@ void main() { } const int dr = ir * BM + lr * TM; - const int dc = ic * BN + lc; + const int dc = ic * BN + lc * TN; - [[unroll]] for (int k = 0; k < TM; k++) { - data_d[dc * p.stride_d + dr + k] = sums[k]; + [[unroll]] for (int cc = 0; cc < TN; cc++) { + [[unroll]] for (int cr = 0; cr < TM; cr++) { + data_d[(dc + cc) * p.stride_d + dr + cr] = sums[cc * TM + cr]; + } } } diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 133f2a399..99265f990 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -472,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr 0, { descriptor_set }, {}); - cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1); + cmd_buffer.dispatch(CEIL_DIV(ne01, 128), CEIL_DIV(ne11, 128), 1); cmd_buffer.end(); vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);