diff --git a/ggml-metal.metal b/ggml-metal.metal index 30d60fa58..8867dcfce 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32( // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } + sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { // prefetch next x block @@ -405,39 +408,50 @@ kernel void kernel_mul_mat_q4_0_f32( // calculate float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; + float acc = sumy; for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); } - sumf[row] += d * (acc[0] - 8.f*acc[1]); + sumf[row] += d * acc; qb_curr = qb_next; } } - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; - for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + if (nb % N_SIMDWIDTH == 0) { + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * (acc[0] - 8.f*acc[1]); - } - qb_curr = qb_next; + } else { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + } + sumy *= (-8.f); + + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; + + // calculate + float d = qb_curr.d; + float acc = sumy; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * acc; + } + qb_curr = qb_next; + + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } } }