From dd296101531af342332d00e23c729454121aee14 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 3 Jan 2024 18:59:14 +0100 Subject: [PATCH] iq2_xxs: Metal dot product now works We have PP-512 = 475 t/s TG-128 = 47.3 t/s Not the greatest performance, but not complete garbage either. --- ggml-metal.metal | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 33fe32dcb..a6f1b1745 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3601,10 +3601,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; - uint32_t aux32[2]; - thread const uint8_t * aux8 = (thread const uint8_t *)aux32; - thread uint16_t * aux16 = (thread uint16_t *)aux32; - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { for (int i = 0; i < 32; ++i) { @@ -3620,13 +3616,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const float db = xr->d; device const uint16_t * q2 = xr->qs + 4 * ib; - for (int i = 0; i < 4; ++i) aux16[i] = q2[i]; - const float d = db * (0.5f + (aux32[1] >> 28)); + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); float sum = 0; for (int l = 0; l < 4; ++l) { constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[l]); - const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); }