ggml : q5_0 scalar dot product

This commit is contained in:
Georgi Gerganov 2023-04-26 13:58:47 +03:00
parent 99238e4c28
commit ef8e3ee6f5
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

42
ggml.c
View file

@@ -3167,18 +3167,16 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
}
static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
GGML_ASSERT(false); // TODO xxxxxxxxx
const int nb = n / QK8_1;
assert(n % QK8_1 == 0);
assert(nb % 2 == 0);
assert(QK8_1 == 2*QK5_0);
assert(QK8_1 == QK5_0);
const block_q5_0 * restrict x = vx;
const block_q8_1 * restrict y = vy;
#if defined(__ARM_NEON)
#if defined(__ARM_NEON_XXX)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -3257,43 +3255,37 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
*s = hsum_float_8(acc) + summs;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs;
const uint8_t * restrict x0 = x[i].qs;
const int8_t * restrict y0 = y[i].qs;
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
const uint32_t qh = x[i].qh;
int sxy_0 = 0;
int sxy_1 = 0;
const float d = GGML_FP16_TO_FP32(x[i].d);
const float m = GGML_FP16_TO_FP32(x[i].m);
for (int j = 0; j < QK8_1/4; j++) {
int sxy = 0;
for (int j = 0; j < QK8_1/2; j++) {
const uint8_t v0 = x0[j];
const uint8_t v1 = x1[j];
const int x0_0 = v0 & 0x0F;
const int x1_0 = v0 >> 4;
const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
const int x0_1 = v1 & 0x0F;
const int x1_1 = v1 >> 4;
const int x0_0 = (v0 & 0x0F) | x0_0h;
const int x1_0 = (v0 >> 4) | x1_0h;
const int y0_0 = y0[2*j + 0];
const int y1_0 = y0[2*j + 1];
const int y0_1 = y0[2*(j + QK8_1/4) + 0];
const int y1_1 = y0[2*(j + QK8_1/4) + 1];
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
sxy += x0_0*y0_0 + x1_0*y1_0;
}
sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
}
*s = sumf;
#endif
}