From 065cc8cb474be7945d2997047dd926c644899cc5 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 4 Jan 2024 02:25:59 +0100
Subject: [PATCH] iq2_xxs: even faster Metal dot product

TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than the
grid values being in shared memory.
---
 ggml-metal.m     |  4 ++--
 ggml-metal.metal | 10 +++++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 47e474d55..43536e5f4 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1713,7 +1713,7 @@ bool ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS) {
-                                [encoder setThreadgroupMemoryLength:256*8 atIndex:0];
+                                [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q4_K) {
@@ -1981,7 +1981,7 @@ bool ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS) {
-                                [encoder setThreadgroupMemoryLength:256*8 atIndex:0];
+                                [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src2t == GGML_TYPE_Q4_K) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 0ada17bef..0cc535ac7 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -3596,10 +3596,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int nb32 = nb * (QK_K / 32);
 
     threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
     {
-        const int nval = 4;
-        const int pos = (32*sgitg + tiisg)*nval;
+        int nval = 4;
+        int pos = (32*sgitg + tiisg)*nval;
         for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
+        nval = 2;
+        pos = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
@@ -3631,7 +3635,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
             float sum = 0;
             for (int l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
-                const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
                 for (int j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }