Try to use a vectorized zeroing function.
This commit is contained in:
parent
2870bfc6dd
commit
7a00422fa3
1 changed file with 8 additions and 5 deletions
|
@ -30,7 +30,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target)
|
inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target)
|
||||||
{
|
{
|
||||||
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
||||||
uint32_t mask=0x0000FF00;
|
uint32_t mask=0x0000000F;
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
|
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
|
||||||
|
@ -62,9 +62,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
|
|
||||||
int8_t aux8[QK_K];
|
int8_t aux8[QK_K];
|
||||||
int16_t aux16[8];
|
int16_t aux16[8];
|
||||||
float sums [8];
|
float32x8_t sums;
|
||||||
int32_t aux32[8];
|
int32_t aux32[8];
|
||||||
memset(sums, 0, 8*sizeof(float));
|
|
||||||
|
//memset(sums, 0, 8*sizeof(float));
|
||||||
|
|
||||||
|
GGML_F32x8_VEC_ZERO(&sums);
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
@ -110,10 +113,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
q8 += 8; a += 8;
|
q8 += 8; a += 8;
|
||||||
}
|
}
|
||||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
for (int l = 0; l < 8; ++l) ((float *)&sums)[l] += d * aux32[l];
|
||||||
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
||||||
sumf -= dmin * sumi;
|
sumf -= dmin * sumi;
|
||||||
}
|
}
|
||||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
for (int l = 0; l < 8; ++l) sumf += ((float *)&sums)[l];
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue