promote aux16 to a vector.

This commit is contained in:
Julia Longtin 2024-03-23 21:00:51 +00:00
parent 31b8a5afd7
commit 45c94bd89d

View file

@ -22,7 +22,7 @@
#define GGML_F32_EPR 16 #define GGML_F32_EPR 16
typedef float float32x8_t __attribute__((vector_size (64))); typedef float float32x8_t __attribute__((vector_size (64)));
typedef int16_t int16x16_t __attribute__((vector_size (64))); typedef int16_t int16x8_t __attribute__((vector_size (32)));
/* A forward declaration, to keep GCC happy. */ /* A forward declaration, to keep GCC happy. */
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc); void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc);
@ -61,7 +61,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * mins = (const uint8_t*)&utmp[2]; const uint8_t * mins = (const uint8_t*)&utmp[2];
int8_t aux8[QK_K]; int8_t aux8[QK_K];
int16_t aux16[8]; int16x8_t aux16 __attribute__((aligned(64)));
float32x8_t sums __attribute__((aligned(64))); float32x8_t sums __attribute__((aligned(64)));
int32_t aux32[8]; int32_t aux32[8];
@ -97,17 +97,17 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
int is = 0; int is = 0;
for (int j = 0; j < QK_K/32; ++j) { for (int j = 0; j < QK_K/32; ++j) {
int32_t scale = scales[is++]; int32_t scale = scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) ((int16_t *)aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)aux16)[l];
q8 += 8; a += 8; q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) ((int16_t *)aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)aux16)[l];
q8 += 8; a += 8; q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) ((int16_t *)aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)aux16)[l];
q8 += 8; a += 8; q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) ((int16_t *)aux16)[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * ((int16_t *)aux16)[l];
q8 += 8; a += 8; q8 += 8; a += 8;
} }
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;