diff --git a/ggml-metal.metal b/ggml-metal.metal index c6e14f46e..e47d97bea 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -29,6 +29,7 @@ typedef struct { typedef struct { half d; // delta half m; // min + uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; @@ -2235,10 +2236,10 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const uint32_t qh = *((device const uint32_t *)xb->qh); - const int x_mv = (il ? 4 : 0); + const int x_mv = il ? 4 : 0; - const int gh_mv = (il ? 12 : 0); - const int gh_bk = (il ? 0 : 4); + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2256,7 +2257,30 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg template void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - // TODO + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } } template