diff --git a/ggml-metal.metal b/ggml-metal.metal index a345b83d9..69fc71362 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -414,8 +414,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -432,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -449,9 +455,12 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); @@ -468,9 +477,12 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; + float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); @@ -2250,7 +2262,7 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const int x_mv = il ? 4 : 0; const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2258,7 +2270,7 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit - const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); reg[i/2][2*(i%2)+0] = d * x0 + md; @@ -2278,7 +2290,7 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg const int x_mv = il ? 4 : 0; const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; + const int gh_bk = il ? 0 : 4; for (int i = 0; i < 8; i++) { // extract the 5-th bits for x0 and x1 @@ -2286,7 +2298,7 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; // combine the 4-bits from qs with the 5th bit - const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0); + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); reg[i/2][2*(i%2)+0] = d * x0 + m;