loosen alignment requirements for zeros, add missing function, and promote aux8 to an array of vectors.

This commit is contained in:
Julia Longtin 2024-03-24 13:35:05 +00:00
parent 1c182a3896
commit e579af1e95

View file

@ -35,7 +35,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
/* clear a vector of 8 floats. */ /* clear a vector of 8 floats. */
inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target) inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target)
{ {
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0}; uint8_t zero[4] __attribute__((aligned(32))) = {0,0,0,0};
uint32_t mask=0x000000FF; uint32_t mask=0x000000FF;
__asm__ __volatile__ ( __asm__ __volatile__ (
@ -48,10 +48,23 @@ inline static void GGML_F32x8_VEC_ZERO(float32x8_t *target)
: "zmm8", "k1", "memory"); : "zmm8", "k1", "memory");
} }
/* clear a vector of 16 floats. */
inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target)
{
uint8_t zero[4] __attribute__((aligned(32))) = {0,0,0,0};
__asm__ __volatile__ (
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register.
"vmovaps\t\t%%zmm8,\t%[RES]\n\t"
: [RES] "+m" (*target)
: [Z] "m" (zero)
: "zmm8", "memory");
}
/* clear a vector of 8 int32_ts. */ /* clear a vector of 8 int32_ts. */
inline static void GGML_I32x8_VEC_ZERO(int32x8_t *target) inline static void GGML_I32x8_VEC_ZERO(int32x8_t *target)
{ {
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0}; uint8_t zero[4] __attribute__((aligned(32))) = {0,0,0,0};
uint32_t mask=0x000000FF; uint32_t mask=0x000000FF;
__asm__ __volatile__ ( __asm__ __volatile__ (
@ -67,7 +80,7 @@ inline static void GGML_I32x8_VEC_ZERO(int32x8_t *target)
/* clear a vector of 16 int32_ts. */ /* clear a vector of 16 int32_ts. */
inline static void GGML_I32x16_VEC_ZERO(int32x16_t *target) inline static void GGML_I32x16_VEC_ZERO(int32x16_t *target)
{ {
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0}; uint8_t zero[4] __attribute__((aligned(32))) = {0,0,0,0};
__asm__ __volatile__ ( __asm__ __volatile__ (
"vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register. "vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register.
@ -132,9 +145,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2]; const uint8_t * mins = (const uint8_t*)&utmp[2];
int8_t aux8[QK_K];
int8x16_t aux8x16[QK_K/16] __attribute__((aligned(32)));
float32x16_t sums __attribute__((aligned(128))); float32x16_t sums __attribute__((aligned(128)));
int8x16_t aux8[QK_K/16] __attribute__((aligned(32)));
int16x16_t aux16 __attribute__((aligned(64))); int16x16_t aux16 __attribute__((aligned(64)));
int32x16_t aux32 __attribute__((aligned(128))); int32x16_t aux32 __attribute__((aligned(128)));
@ -146,8 +158,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * restrict hm = x[i].qh; const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
int8_t * restrict a = aux8_16; int8_t * restrict a = (int8_t * restrict)aux8;
uint8_t m = 1; uint8_t m = 1;
// Fill the 8 bit vector a with our 5 bit quantization data, 64 blocks at a time.
for (int j = 0; j < QK_K/64; ++j) { for (int j = 0; j < QK_K/64; ++j) {
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
@ -157,12 +171,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
a += 32; m <<= 1; a += 32; m <<= 1;
q4 += 32; q4 += 32;
} }
memcpy(utmp, x[i].scales, 12); memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1; const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux; utmp[2] = uaux;
utmp[0] &= kmask1; utmp[0] &= kmask1;
a = (int8_t * restrict)aux8;
int sumi = 0; int sumi = 0;