promote aux32 to a vector.

This commit is contained in:
Julia Longtin 2024-03-23 21:12:35 +00:00
parent 3c29fd57ce
commit 10237df57a

View file

@ -23,6 +23,7 @@
typedef float float32x8_t __attribute__((vector_size (64)));
typedef int16_t int16x8_t __attribute__((vector_size (32)));
typedef int32_t int32x8_t __attribute__((vector_size (64)));
/* A forward declaration, to keep GCC happy. */
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc);
@ -63,7 +64,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
int8_t aux8[QK_K];
int16x8_t aux16 __attribute__((aligned(64)));
float32x8_t sums __attribute__((aligned(64)));
int32_t aux32[8];
int32x8_t aux32 __attribute__((aligned(64)));
GGML_F32x8_VEC_ZERO(&sums);
@ -98,20 +99,20 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int j = 0; j < QK_K/32; ++j) {
int32_t scale = scales[is++];
for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)&aux16)[l];
for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)&aux16)[l];
for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)&aux16)[l];
for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)&aux16)[l];
for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
q8 += 8; a += 8;
}
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) ((float *)&sums)[l] += d * aux32[l];
for (int l = 0; l < 8; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf -= dmin * sumi;
}