diff --git a/ggml-metal.metal b/ggml-metal.metal index e47d97bea..9df4b8cb4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -466,7 +466,18 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - // TODO + float d = qb_curr->d; + float m = qb_curr->m; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 4 ) & 0x0010)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 12) & 0x1000)); + } + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty