diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c94bc9389..16636c2a7 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3507,8 +3507,6 @@ void kernel_mul_mv_q2_K_f32_impl( float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = nb01; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -3553,9 +3551,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += step/2; - sc += step; - dh += step/2; + qs += nb01/2; + sc += nb01; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3678,8 +3676,6 @@ void kernel_mul_mv_q3_K_f32_impl( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - const int step = nb01 / 2; - device const float * y1 = yy + ix*QK_K + y_offset; uint32_t scales32, aux32; @@ -3741,10 +3737,10 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += step; - h += step; - a += step; - dh += step; + q += nb01/2; + h += nb01/2; + a += nb01/2; + dh += nb01/2; } y1 += 4 * QK_K; @@ -3844,8 +3840,6 @@ void kernel_mul_mv_q4_K_f32_impl( float yh[16]; float sumf[N_DST]={0.f}, all_sum; - const int step = nb01 / 2; - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; uint16_t sc16[4]; @@ -3893,9 +3887,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - sc += step; - dh += step; + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3982,8 +3976,6 @@ void kernel_mul_mv_q5_K_f32_impl( float sumf[2]={0.f}; - const int step = nb01; - float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -4054,10 +4046,10 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - qh += step; - dh += step/2; - a += step/2; + q1 += nb01; + qh += nb01; + dh += nb01/2; + a += nb01/2; } y1 += 4 * QK_K;