copy right block.
This commit is contained in:
parent
e99f3a9bf4
commit
656bf28c91
1 changed files with 48 additions and 28 deletions
|
@ -30,7 +30,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target)
|
||||
{
|
||||
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
||||
uint32_t mask=0x000000FF;
|
||||
uint32_t mask=0x0000FF00;
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
|
||||
|
@ -55,43 +55,63 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
uint32_t utmp[4];
|
||||
int8_t aux8[QK_K];
|
||||
// int16_t aux16[16];
|
||||
int16x16_t aux16 __attribute__((aligned(64)));
|
||||
float32x8_t sums __attribute__((aligned(64)));
|
||||
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
||||
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
||||
|
||||
/* use a vector operation to clear these floats. */
|
||||
GGML_F32x8_VEC_ZERO(&sums);
|
||||
int8_t aux8[QK_K];
|
||||
int16_t aux16[8];
|
||||
float sums [8];
|
||||
int32_t aux32[8];
|
||||
memset(sums, 0, 8*sizeof(float));
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// quants, 4 low bits.
|
||||
const uint8_t * restrict q4 = x[i].qs;
|
||||
// quants, 1 high bit.
|
||||
const uint8_t * restrict hm = x[i].qh;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
memset(aux32, 0, 8*sizeof(int32_t));
|
||||
int8_t * restrict a = aux8;
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
a[l+ 0] = q4[l] & 0xF;
|
||||
a[l+32] = q4[l] >> 4;
|
||||
}
|
||||
for (int is = 0; is < 8; ++is) {
|
||||
uint8_t m = 1 << is;
|
||||
for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
|
||||
uint8_t m = 1;
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
|
||||
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
||||
a += 32; m <<= 1;
|
||||
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
|
||||
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
||||
a += 32; m <<= 1;
|
||||
q4 += 32;
|
||||
}
|
||||
memcpy(utmp, x[i].scales, 12);
|
||||
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux = utmp[1] & kmask1;
|
||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
const int8_t * restrict sc = x[i].scales;
|
||||
|
||||
for (int j = 0; j < QK_K/16; ++j) {
|
||||
const float dl = d * sc[j];
|
||||
for (int l = 0; l < 16; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
|
||||
for (int l = 0; l < 8; ++l) ((float *)&sums)[l] += dl * (((int16_t *)&aux16)[l] + ((int16_t *)&aux16)[8+l]);
|
||||
q8 += 16; a += 16;
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
||||
a = aux8;
|
||||
int is = 0;
|
||||
for (int j = 0; j < QK_K/32; ++j) {
|
||||
int32_t scale = scales[is++];
|
||||
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||
q8 += 8; a += 8;
|
||||
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||
q8 += 8; a += 8;
|
||||
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||
q8 += 8; a += 8;
|
||||
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||
q8 += 8; a += 8;
|
||||
}
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
||||
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
||||
sumf -= dmin * sumi;
|
||||
}
|
||||
for (int l = 0; l < 8; ++l) sumf += ((float *)&sums)[l];
|
||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||
*s = sumf;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue