cont : use nb01 directly for row steps

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-25 16:27:03 +03:00
parent 5409a21e1b
commit 18989be340
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -3507,8 +3507,6 @@ void kernel_mul_mv_q2_K_f32_impl(
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int step = nb01;
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
@ -3553,9 +3551,9 @@ void kernel_mul_mv_q2_K_f32_impl(
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
qs += step/2;
sc += step;
dh += step/2;
qs += nb01/2;
sc += nb01;
dh += nb01/2;
}
y4 += 4 * QK_K;
@ -3678,8 +3676,6 @@ void kernel_mul_mv_q3_K_f32_impl(
const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0;
const int step = nb01 / 2;
device const float * y1 = yy + ix*QK_K + y_offset;
uint32_t scales32, aux32;
@ -3741,10 +3737,10 @@ void kernel_mul_mv_q3_K_f32_impl(
sumf1[row] += d1 * (scales[1] - 32);
sumf2[row] += d2 * (scales[3] - 32);
q += step;
h += step;
a += step;
dh += step;
q += nb01/2;
h += nb01/2;
a += nb01/2;
dh += nb01/2;
}
y1 += 4 * QK_K;
@ -3844,8 +3840,6 @@ void kernel_mul_mv_q4_K_f32_impl(
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
const int step = nb01 / 2;
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
uint16_t sc16[4];
@ -3893,9 +3887,9 @@ void kernel_mul_mv_q4_K_f32_impl(
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;
sc += step;
dh += step;
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += 4 * QK_K;
@ -3982,8 +3976,6 @@ void kernel_mul_mv_q5_K_f32_impl(
float sumf[2]={0.f};
const int step = nb01;
float yl[16], yh[16];
const uint16_t kmask1 = 0x3f3f;
@ -4054,10 +4046,10 @@ void kernel_mul_mv_q5_K_f32_impl(
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;
qh += step;
dh += step/2;
a += step/2;
q1 += nb01;
qh += nb01;
dh += nb01/2;
a += nb01/2;
}
y1 += 4 * QK_K;