From e211fadc8a52af67bad3f543a67000ae471363ec Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 4 Jan 2024 02:06:23 +0100 Subject: [PATCH] iq2_xxs: slighty faster dot product TG-128 is now 50.9 t/s --- ggml-metal.m | 12 ++++++++++-- ggml-metal.metal | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e422dbb4e..47e474d55 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1708,10 +1708,14 @@ bool ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_IQ2_XXS || + //src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ2_XXS) { + [encoder setThreadgroupMemoryLength:256*8 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1972,10 +1976,14 @@ bool ggml_metal_graph_compute( if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || - src2t == GGML_TYPE_IQ2_XXS || + //src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ2_XXS) { + [encoder setThreadgroupMemoryLength:256*8 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } diff --git a/ggml-metal.metal b/ggml-metal.metal index 94fc71235..0ada17bef 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3569,6 +3569,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3594,6 +3595,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + { + const int nval = 4; + const int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + #if QK_K == 256 const int ix = tiisg; @@ -3621,7 +3630,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[l]); + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); @@ -3668,11 +3677,12 @@ kernel void kernel_mul_mv_iq2_xxs_f32( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } //============================= templates and their specializations ============================= @@ -5403,6 +5413,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( device const char * src05, device const char * src06, device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], @@ -5428,6 +5439,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( ne1, r2, r3, + shared_values, tgpig, tiisg, sgitg);