use GGML_F32_EPR, and remove some dead code.

This commit is contained in:
Julia Longtin 2024-04-03 22:04:45 +00:00
parent 84df774d6a
commit 9ad5efafb0

View file

@ -6,13 +6,10 @@
// Yes, we have to tell this header to actually export stuff. // Yes, we have to tell this header to actually export stuff.
#define GGML_COMMON_IMPL_C #define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-impl.h" #include "ggml-impl.h"
// FIXME: why do we have to import this twice? // For block_q5_K and block_q8_K.
#define GGML_COMMON_IMPL_C
// For block_q5_K and block_q8_K. only given the second time.
#include "ggml-common.h" #include "ggml-common.h"
// This SIMD unit can work with 32 float32s at once. // This SIMD unit can work with 32 float32s at once.
@ -213,23 +210,17 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
utmp[2] = uaux; utmp[2] = uaux;
utmp[0] &= kmask1; utmp[0] &= kmask1;
a = (int8_t * restrict)aux8;
int sumi = 0;
GGML_I32x16_VEC_ZERO(&aux32);
// FIXME: while comparing FMA output to the original output, the original had an error. hunt it down. // FIXME: while comparing FMA output to the original output, the original had an error. hunt it down.
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16(q8copy, aux8, scales, &aux32); GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16(q8copy, aux8, scales, &aux32);
int sumi = 0; int sumi = 0;
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 16; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l]; for (int l = 0; l < GGML_F32_EPR; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf -= dmin * sumi; sumf -= dmin * sumi;
} }
for (int l = 0; l < 16; ++l) sumf += ((float *)&sums)[l]; for (int l = 0; l < GGML_F32_EPR; ++l) sumf += ((float *)&sums)[l];
*s = sumf; *s = sumf;
} }