From 1b4863c2b939a60842fbcf39f53be94a6687c2ce Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sat, 24 Jun 2023 08:01:43 +0200
Subject: [PATCH] 1D Blocktiling

Replace the 32x32 tiling, where every invocation produced one output
element, with 64x64 workgroup tiles in which every invocation accumulates
TM = 8 outputs of a single column.

The M/N bounds check on the final store is kept (now per element): the
host rounds the dispatch up to whole tiles, so edge tiles overhang the
matrix whenever M or N is not a multiple of the tile size.

---
Note: hunk headers and diffstat were adjusted by hand for the restored
bounds check and the added comments; the post-image blob id on the
shader's index line was not recomputed.

 ggml-vulkan-matmul.comp | 63 ++++++++++++++++++++++++++---------------
 ggml-vulkan.cpp         |  6 +++--
 2 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/ggml-vulkan-matmul.comp b/ggml-vulkan-matmul.comp
index 4341a51ac..3166bb659 100644
--- a/ggml-vulkan-matmul.comp
+++ b/ggml-vulkan-matmul.comp
@@ -1,10 +1,15 @@
 #version 450
 
-#define BLOCKSIZE 32
+// Tile sizes: a workgroup produces a BM x BN tile of D, stepping BK along K;
+// each invocation accumulates TM outputs. main() relies on BM == BN == BK * TM.
+#define BM 64
+#define BN 64
+#define BK 8
+#define TM 8
 
 #extension GL_EXT_control_flow_attributes : enable
 
-layout(local_size_x = BLOCKSIZE * BLOCKSIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A { float data_a[]; };
 layout (binding = 1) readonly buffer B { float data_b[]; };
@@ -20,42 +25,54 @@ layout (push_constant) uniform parameter
     int stride_d;
 } p;
 
-shared float buf_a[(BLOCKSIZE+1) * BLOCKSIZE];
-shared float buf_b[(BLOCKSIZE+1) * BLOCKSIZE];
-
-void main()
-{
-    const int lr = int(gl_LocalInvocationID.x % BLOCKSIZE);
-    const int lc = int(gl_LocalInvocationID.x / BLOCKSIZE);
+shared float buf_a[BM * (BK+1)];
+shared float buf_b[BN * (BK+1)];
 
+void main() {
     const int ir = int(gl_WorkGroupID.x);
     const int ic = int(gl_WorkGroupID.y);
 
-    int pos_a = ir * BLOCKSIZE * p.stride_a;
-    int pos_b = ic * BLOCKSIZE * p.stride_b;
+    // Load phase: (lc, lr) = (tile row, k offset). Compute phase: lr = TM-row group, lc = column.
+    const int lr = int(gl_LocalInvocationID.x % BK);
+    const int lc = int(gl_LocalInvocationID.x / BK);
 
-    float sum = 0.0f;
+    int pos_a = ir * BM * p.stride_a;
+    int pos_b = ic * BN * p.stride_b;
 
-    [[unroll]] for (int i = 0; i < p.K; i += BLOCKSIZE) {
-        buf_a[lc * (BLOCKSIZE+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
-        buf_b[lc * (BLOCKSIZE+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
+    float sums[TM];
+    float btmp;
+
+    [[unroll]] for (int k = 0; k < TM; k++) {
+        sums[k] = 0.0f;
+    }
+
+    for (int i = 0; i < p.K; i += BK) {
+        // A is transposed
+        buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
+        buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
 
         barrier();
 
-        pos_a += BLOCKSIZE;
-        pos_b += BLOCKSIZE;
+        pos_a += BK;
+        pos_b += BK;
 
-        [[unroll]] for (int j = 0; j < BLOCKSIZE; j++) {
-            sum += buf_a[lr * (BLOCKSIZE+1) + j] * buf_b[lc * (BLOCKSIZE+1) + j];
+        [[unroll]] for (int j = 0; j < BK; j++) {
+            btmp = buf_b[lc * (BK+1) + j];
+            [[unroll]] for (int k = 0; k < TM; k++) {
+                sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp;
+            }
         }
 
         barrier();
     }
 
-    const int dr = ir * BLOCKSIZE + lr;
-    const int dc = ic * BLOCKSIZE + lc;
+    const int dr = ir * BM + lr * TM;
+    const int dc = ic * BN + lc;
 
-    if (dr < p.M && dc < p.N) {
-        data_d[dc * p.stride_d + dr] = sum;
+    [[unroll]] for (int k = 0; k < TM; k++) {
+        // The host rounds the dispatch up to whole tiles, so edge tiles overhang M/N.
+        if (dr + k < p.M && dc < p.N) {
+            data_d[dc * p.stride_d + dr + k] = sums[k];
+        }
     }
 }
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 77de265dc..133f2a399 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -38,6 +38,8 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
 
 #define VK_API_VERSION VK_API_VERSION_1_2
 
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+
 vk::Instance vk_instance;
 uint32_t vk_compute_queue_family_index;
 vk::PhysicalDevice vk_physical_device;
@@ -470,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
                                           0,
                                           { descriptor_set },
                                           {});
-            cmd_buffer.dispatch((ne01 + 31) / 32, (ne11 + 31) / 32, 1);
+            cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1);
             cmd_buffer.end();
 
             vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
@@ -494,7 +496,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
             ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
 
-#ifdef false
+#if 0
             const float * x = (float *) ((char *) src0->data);
             const float * y = (float *) ((char *) src1->data);
             float * d_chk = (float *) malloc(sizeof(float) * d_ne);