perform 16 operations at a time.

Julia Longtin 2024-03-24 12:04:44 +00:00
parent d34e0ff835
commit 0c01d07835


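The change widens the q5_K × q8_K helpers from eight lanes to sixteen: a 512-bit zmm register holds sixteen 32-bit values, so the old helpers (masked down to eight lanes via k1, "le sigh") were only using half the register. In plain C, the new sixteen-wide FMA helper computes the same thing as the scalar loops it replaces; here is a reference sketch (the function name is mine, not from the commit):

#include <stdint.h>

/* Scalar reference for GGML_I16x16_S_FMA_I32x16: a sixteen-lane
   multiply-accumulate of an int16 vector by a scalar into int32 lanes. */
static void i16x16_s_fma_i32x16_ref(const int16_t src[16], int32_t scale,
                                    int32_t dest[16])
{
    for (int l = 0; l < 16; ++l)
        dest[l] += scale * (int32_t)src[l];
}
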
@@ -24,6 +24,8 @@
 typedef float float32x8_t __attribute__((vector_size (64)));
 typedef int16_t int16x8_t __attribute__((vector_size (32)));
 typedef int32_t int32x8_t __attribute__((vector_size (64)));
+typedef int16_t int16x16_t __attribute__((vector_size (64)));
+typedef int32_t int32x16_t __attribute__((vector_size (128)));
 
 /* A forward declaration, to keep GCC happy. */
 void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc);
@@ -58,6 +60,19 @@ inline static void GGML_I32x8_VEC_ZERO(int32x8_t *target)
         : "zmm8", "k1", "memory");
 }
+
+inline static void GGML_I32x16_VEC_ZERO(int32x16_t *target)
+{
+    uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
+
+    __asm__ __volatile__ (
+        "vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
+        "vmovaps\t\t%%zmm8,\t%[RES]\n\t"              // store all sixteen lanes; at full register width no write mask is needed.
+        : [RES] "+m" (*target)
+        : [Z] "m" (zero)
+        : "zmm8", "memory");
+}
 
 // perform an eight wide Fused Multiply Add of an I16x8 times scalar S into I32x8.
 inline static void GGML_I16x8_S_FMA_I32x8 (int16x8_t *src, int32_t scale, int32x8_t *dest)
 {
@@ -66,15 +81,12 @@ inline static void GGML_I16x8_S_FMA_I32x8 (int16x8_t *src, int32_t scale, int32x8_t *dest)
     int32_t scaleVec[4] = {scale, scale, scale, scale};
 
     __asm__ __volatile__ (
-        "vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm0\n\t" // use an upscaling operator to clear our value.
-        "vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm1\n\t" // use an upscaling operator to clear our value.
-        "vbroadcastI32x4\t%[Z]%{uint8%},\t%%zmm2\n\t" // use an upscaling operator to clear our value.
         "kmov\t%[M],\t%%k1\n\t" // we will only be working with 8 values at a time. le sigh.
         "vmovdqa32\t\t%[SRC]%{sint16%},\t%%zmm0%{%%k1%}\n\t" // load the item we will be summing from. upscale it from int16.
         "vbroadcastI32x4\t%[SCALE],\t%%zmm1\n\t" // load the item we will be multiplying by.
         "vmovdqa32\t\t%[RES],\t%%zmm2%{%%k1%}\n\t" // load the item we will be summing onto.
         "vpmadd231d\t%%zmm0,\t%%zmm1,\t%%zmm2%{%%k1%}\n\t" // perform our multiply-add.
         "vmovdqa32\t\t%%zmm2,\t%[RES]%{%%k1%}\n\t" // save the result.
         : [RES] "+m" (*dest)
         : [Z] "m" (zero),
           [M] "r" (mask),
@@ -83,6 +95,23 @@ inline static void GGML_I16x8_S_FMA_I32x8 (int16x8_t *src, int32_t scale, int32x8_t *dest)
         : "zmm0", "zmm1", "zmm2", "k1", "memory");
 }
+
+// perform a sixteen wide Fused Multiply Add of an I16x16 times scalar S into I32x16.
+inline static void GGML_I16x16_S_FMA_I32x16 (int16x16_t *src, int32_t scale, int32x16_t *dest)
+{
+    int32_t scaleVec[4] = {scale, scale, scale, scale};
+
+    __asm__ __volatile__ (
+        "vmovdqa32\t\t%[SRC]%{sint16%},\t%%zmm0\n\t" // load the item we will be summing from. upscale it from int16.
+        "vbroadcastI32x4\t%[SCALE],\t%%zmm1\n\t"     // load the item we will be multiplying by.
+        "vmovdqa32\t\t%[RES],\t%%zmm2\n\t"           // load the item we will be summing onto.
+        "vpmadd231d\t%%zmm0,\t%%zmm1,\t%%zmm2\n\t"   // perform our multiply-add.
+        "vmovdqa32\t\t%%zmm2,\t%[RES]\n\t"           // save the result.
+        : [RES] "+m" (*dest)
+        : [SRC] "m" (*src),
+          [SCALE] "m" (scaleVec)
+        : "zmm0", "zmm1", "zmm2", "memory");
+}
 
 void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 
     /* interpret X and Y as vectors. */
@@ -101,19 +130,20 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const uint8_t * scales = (const uint8_t*)&utmp[0];
     const uint8_t * mins = (const uint8_t*)&utmp[2];
 
     int8_t aux8[QK_K];
-    int16x8_t aux16 __attribute__((aligned(64)));
-    float32x8_t sums __attribute__((aligned(64)));
-    int32x8_t aux32 __attribute__((aligned(64)));
+    int16x16_t aux16 __attribute__((aligned(128)));
+    float32x16_t sums __attribute__((aligned(64)));
+    int32x16_t aux32 __attribute__((aligned(128)));
 
-    GGML_F32x8_VEC_ZERO(&sums);
+    GGML_F32x16_VEC_ZERO(&sums);
 
     float sumf = 0;
 
     for (int i = 0; i < nb; ++i) {
         const uint8_t * restrict q4 = x[i].qs;
         const uint8_t * restrict hm = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
-        GGML_I32x8_VEC_ZERO(&aux32);
+        GGML_I32x16_VEC_ZERO(&aux32);
 
         int8_t * restrict a = aux8;
         uint8_t m = 1;
@@ -139,24 +169,19 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
         int is = 0;
         for (int j = 0; j < QK_K/32; ++j) {
             int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) ((int32_t *)&aux32)[l] += scale * ((int16_t *)&aux16)[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
-            GGML_I16x8_S_FMA_I32x8 (&aux16, scale, &aux32);
-            q8 += 8; a += 8;
+            for (int l = 0; l < 16; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
+            GGML_I16x16_S_FMA_I32x16 (&aux16, scale, &aux32);
+            q8 += 16; a += 16;
+            /* FIXME: while comparing FMA output to normal output, the original had an error. hunt it down. */
+            for (int l = 0; l < 16; ++l) ((int16_t *)&aux16)[l] = q8[l] * a[l];
+            GGML_I16x16_S_FMA_I32x16 (&aux16, scale, &aux32);
+            q8 += 16; a += 16;
         }
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
+        for (int l = 0; l < 16; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
         const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
         sumf -= dmin * sumi;
     }
-    for (int l = 0; l < 8; ++l) sumf += ((float *)&sums)[l];
+    for (int l = 0; l < 16; ++l) sumf += ((float *)&sums)[l];
     *s = sumf;
 }
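
The FIXME in the last hunk mentions comparing the FMA output against the plain C ("normal") output. A throwaway harness for that comparison might look like the sketch below; it is hypothetical, not part of the commit, and the call to the new asm helper is left in a comment since the IMCI-style mnemonics suggest it only assembles for a Knights Corner target:

#include <stdint.h>
#include <stdio.h>

typedef int16_t int16x16_t __attribute__((vector_size (64)));
typedef int32_t int32x16_t __attribute__((vector_size (128)));

int main(void)
{
    int16x16_t src __attribute__((aligned(64)));
    int32x16_t acc __attribute__((aligned(128)));
    int32_t ref[16];
    const int32_t scale = 5;

    for (int l = 0; l < 16; ++l) {
        ((int16_t *)&src)[l] = (int16_t)(3 * l - 7);  /* arbitrary test pattern */
        ((int32_t *)&acc)[l] = ref[l] = 100 + l;      /* identical starting accumulators */
    }

    /* the "normal" path: the scalar loops the commit replaces. */
    for (int l = 0; l < 16; ++l)
        ref[l] += scale * ((int16_t *)&src)[l];

    /* the FMA path; on a Knights Corner build, call the new helper instead:
       GGML_I16x16_S_FMA_I32x16(&src, scale, &acc);
       a scalar stand-in keeps the harness runnable anywhere. */
    for (int l = 0; l < 16; ++l)
        ((int32_t *)&acc)[l] += scale * ((int16_t *)&src)[l];

    for (int l = 0; l < 16; ++l)
        if (((int32_t *)&acc)[l] != ref[l])
            printf("lane %d differs: %d != %d\n",
                   l, (int)((int32_t *)&acc)[l], (int)ref[l]);
    return 0;
}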