replace tabs with spaces.
This commit is contained in:
parent
9152143fe7
commit
53773e0b4a
2 changed files with 211 additions and 211 deletions
|
@ -35,11 +35,11 @@ inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target)
|
||||||
uint8_t zero=0;
|
uint8_t zero=0;
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"vbroadcastss\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register.
|
"vbroadcastss\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register.
|
||||||
"vmovaps\t\t%%zmm8,\t%[RES]\n\t"
|
"vmovaps\t\t%%zmm8,\t%[RES]\n\t"
|
||||||
: [RES] "+m" (*target)
|
: [RES] "+m" (*target)
|
||||||
: [Z] "m" (zero)
|
: [Z] "m" (zero)
|
||||||
: "zmm8", "memory");
|
: "zmm8", "memory");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,52 +50,52 @@ inline static void GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16 (int8x16_t
|
||||||
uint8_t zero = 0;
|
uint8_t zero = 0;
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"vprefetche0\t(%[SRC11])\n\t"
|
"vprefetche0\t(%[SRC11])\n\t"
|
||||||
"vprefetche0\t(%[SRC21])\n\t"
|
"vprefetche0\t(%[SRC21])\n\t"
|
||||||
"vprefetche0\t(%[SCALE])\n\t"
|
"vprefetche0\t(%[SCALE])\n\t"
|
||||||
"mov\t$0,\t%%ecx\n\t"
|
"mov\t$0,\t%%ecx\n\t"
|
||||||
"mov\t%[SRC11],\t%%r12\n\t"
|
"mov\t%[SRC11],\t%%r12\n\t"
|
||||||
"mov\t%[SRC21],\t%%r8\n\t"
|
"mov\t%[SRC21],\t%%r8\n\t"
|
||||||
"mov\t%[SCALE],\t%%r9\n\t"
|
"mov\t%[SCALE],\t%%r9\n\t"
|
||||||
"vpbroadcastd\t%[Z]%{uint8%},\t%%zmm7\n\t" // empty our result.
|
"vpbroadcastd\t%[Z]%{uint8%},\t%%zmm7\n\t" // empty our result.
|
||||||
"1:\n\t"
|
"1:\n\t"
|
||||||
"inc\t%%ecx\n\t" // we are in our loop, increment our counter.
|
"inc\t%%ecx\n\t" // we are in our loop, increment our counter.
|
||||||
"cmp\t$4,\t%%ecx\n\t" // see if this is our last run-through.
|
"cmp\t$4,\t%%ecx\n\t" // see if this is our last run-through.
|
||||||
"vmovdqa32\t\t(%%r12)%{sint8%},\t%%zmm0\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
"vmovdqa32\t\t(%%r12)%{sint8%},\t%%zmm0\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||||
"vmovdqa32\t\t(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
"vmovdqa32\t\t(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||||
"vpmulld\t%%zmm0,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
"vpmulld\t%%zmm0,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
||||||
"vpbroadcastd\t(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
"vpbroadcastd\t(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
||||||
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||||
"vmovdqa32\t\t16(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
"vmovdqa32\t\t16(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||||
"vmovdqa32\t\t16(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
"vmovdqa32\t\t16(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||||
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
||||||
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||||
"vmovdqa32\t\t32(%%r12)%{sint8%},\t%%zmm8\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
"vmovdqa32\t\t32(%%r12)%{sint8%},\t%%zmm8\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||||
"vmovdqa32\t\t32(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
"vmovdqa32\t\t32(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||||
"vpmulld\t%%zmm8,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
"vpmulld\t%%zmm8,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
||||||
"vpbroadcastd\t1(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
"vpbroadcastd\t1(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
||||||
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||||
"vmovdqa32\t\t48(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
"vmovdqa32\t\t48(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||||
"vmovdqa32\t\t48(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
"vmovdqa32\t\t48(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||||
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
||||||
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||||
"je\t2f\n\t" // if this is the last time through our loop, jump to 2.
|
"je\t2f\n\t" // if this is the last time through our loop, jump to 2.
|
||||||
"vprefetche0\t64(%%r12)\n\t" // otherwise, prepare for another run-through.
|
"vprefetche0\t64(%%r12)\n\t" // otherwise, prepare for another run-through.
|
||||||
"vprefetche0\t64(%%r8)\n\t"
|
"vprefetche0\t64(%%r8)\n\t"
|
||||||
"vprefetche2\t128(%%r12)\n\t"
|
"vprefetche2\t128(%%r12)\n\t"
|
||||||
"vprefetche2\t128(%%r8)\n\t"
|
"vprefetche2\t128(%%r8)\n\t"
|
||||||
"add\t$64,\t%%r12\n\t"
|
"add\t$64,\t%%r12\n\t"
|
||||||
"add\t$64,\t%%r8\n\t"
|
"add\t$64,\t%%r8\n\t"
|
||||||
"add\t$2,\t%%r9\n\t"
|
"add\t$2,\t%%r9\n\t"
|
||||||
"jmp\t1b\n\t"
|
"jmp\t1b\n\t"
|
||||||
"2:\n\t"
|
"2:\n\t"
|
||||||
"vmovdqa32\t\t%%zmm7,\t(%[RES])\n\t" // save the result.
|
"vmovdqa32\t\t%%zmm7,\t(%[RES])\n\t" // save the result.
|
||||||
: [RES] "+r" (res)
|
: [RES] "+r" (res)
|
||||||
: [SRC11] "r" (src11),
|
: [SRC11] "r" (src11),
|
||||||
[SRC21] "r" (src21),
|
[SRC21] "r" (src21),
|
||||||
[SCALE] "r" (scale),
|
[SCALE] "r" (scale),
|
||||||
[Z] "m" (zero)
|
[Z] "m" (zero)
|
||||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "r8", "r9", "r12", "memory");
|
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "r8", "r9", "r12", "memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unpack 256 unsigned 5 bit values into an 8 bit vector.
|
// Unpack 256 unsigned 5 bit values into an 8 bit vector.
|
||||||
|
@ -107,55 +107,55 @@ inline static void GGML_5bit_Unpack (const uint8x16_t * q4, const uint8_t * q1,
|
||||||
uint8_t bit5 = 0x10;
|
uint8_t bit5 = 0x10;
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"vprefetche0\t(%[SRC1])\n\t" // Issue our memory requests first thing.
|
"vprefetche0\t(%[SRC1])\n\t" // Issue our memory requests first thing.
|
||||||
"vprefetche0\t(%[SRC4])\n\t"
|
"vprefetche0\t(%[SRC4])\n\t"
|
||||||
"vprefetche1\t64(%[SRC4])\n\t"
|
"vprefetche1\t64(%[SRC4])\n\t"
|
||||||
"mov\t%[SRC4],\t%%r12\n\t" // load the address of the head of our 4-bit list.
|
"mov\t%[SRC4],\t%%r12\n\t" // load the address of the head of our 4-bit list.
|
||||||
"mov\t%[DST],\t%%r8\n\t" // load the address of the head of our destination list.
|
"mov\t%[DST],\t%%r8\n\t" // load the address of the head of our destination list.
|
||||||
"mov\t$0,%%ecx\n\t" // initialize our counter.
|
"mov\t$0,%%ecx\n\t" // initialize our counter.
|
||||||
"vmovdqa32\t(%[SRC1])%{uint8%},\t%%zmm6\n\t" // move 16 packed sets of single bits into the lower 8 bits of zmm6.
|
"vmovdqa32\t(%[SRC1])%{uint8%},\t%%zmm6\n\t" // move 16 packed sets of single bits into the lower 8 bits of zmm6.
|
||||||
"vmovdqa32\t16(%[SRC1])%{uint8%},\t%%zmm7\n\t" // move the next 16 packed sets of single bits into the lower 8 bits of zmm7.
|
"vmovdqa32\t16(%[SRC1])%{uint8%},\t%%zmm7\n\t" // move the next 16 packed sets of single bits into the lower 8 bits of zmm7.
|
||||||
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm2\n\t " // load our mask.
|
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm2\n\t " // load our mask.
|
||||||
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm9\n\t" // load the bit we want to add (conditionally).
|
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm9\n\t" // load the bit we want to add (conditionally).
|
||||||
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm8\n\t" // select which bit we want to test for.
|
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm8\n\t" // select which bit we want to test for.
|
||||||
"1:\n\t"
|
"1:\n\t"
|
||||||
"inc\t%%ecx\n\t" // we are in the loop. increment the counter.
|
"inc\t%%ecx\n\t" // we are in the loop. increment the counter.
|
||||||
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
||||||
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
||||||
"vmovdqa32\t\t(%%r12)%{uint8%},\t%%zmm0\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
"vmovdqa32\t\t(%%r12)%{uint8%},\t%%zmm0\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
"vpandd\t%%zmm0,\t%%zmm2,\t%%zmm4\n\t" // apply a mask, storing the low four bits of vector zmm0 into zmm4.
|
"vpandd\t%%zmm0,\t%%zmm2,\t%%zmm4\n\t" // apply a mask, storing the low four bits of vector zmm0 into zmm4.
|
||||||
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||||
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||||
"vmovdqa32\t\t16(%%r12)%{uint8%},\t%%zmm1\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
"vmovdqa32\t\t16(%%r12)%{uint8%},\t%%zmm1\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
"vpandd\t%%zmm1,\t%%zmm2,\t%%zmm5\n\t" // apply a mask, storing the next low four bits of vector zmm1 into zmm5.
|
"vpandd\t%%zmm1,\t%%zmm2,\t%%zmm5\n\t" // apply a mask, storing the next low four bits of vector zmm1 into zmm5.
|
||||||
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||||
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||||
"add\t$32,\t%%r8\n\t"
|
"add\t$32,\t%%r8\n\t"
|
||||||
"cmp\t$4,\t%%ecx\n\t"
|
"cmp\t$4,\t%%ecx\n\t"
|
||||||
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
||||||
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
||||||
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
||||||
"vpsrld\t$4,\t%%zmm0,\t%%zmm4\n\t" // load our even 4 bit sequence into zmm4.
|
"vpsrld\t$4,\t%%zmm0,\t%%zmm4\n\t" // load our even 4 bit sequence into zmm4.
|
||||||
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||||
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||||
"vpsrld\t$4,\t%%zmm1,\t%%zmm5\n\t" // load our even 4 bit sequence into zmm5.
|
"vpsrld\t$4,\t%%zmm1,\t%%zmm5\n\t" // load our even 4 bit sequence into zmm5.
|
||||||
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||||
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||||
"je\t2f\n\t"
|
"je\t2f\n\t"
|
||||||
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
||||||
"add\t$32,\t%%r12\n\t"
|
"add\t$32,\t%%r12\n\t"
|
||||||
"add\t$32,\t%%r8\n\t"
|
"add\t$32,\t%%r8\n\t"
|
||||||
"jmp\t1b\n\t"
|
"jmp\t1b\n\t"
|
||||||
"2:"
|
"2:"
|
||||||
: [DST] "+r" (dst)
|
: [DST] "+r" (dst)
|
||||||
: [SRC4] "r" (q4),
|
: [SRC4] "r" (q4),
|
||||||
[SRC1] "r" (q1),
|
[SRC1] "r" (q1),
|
||||||
[MASK] "m" (lowmask),
|
[MASK] "m" (lowmask),
|
||||||
[M] "m" (m),
|
[M] "m" (m),
|
||||||
[ALL] "m" (allmask),
|
[ALL] "m" (allmask),
|
||||||
[BIT5] "m" (bit5)
|
[BIT5] "m" (bit5)
|
||||||
: "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "cc", "ecx", "k1", "k2", "r12", "r8", "memory"
|
: "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "cc", "ecx", "k1", "k2", "r12", "r8", "memory"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// A function for getting the dot product of two vectors, one of 5 bit resolution, and one of 8.
|
// A function for getting the dot product of two vectors, one of 5 bit resolution, and one of 8.
|
||||||
|
@ -185,37 +185,37 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
int8x16_t q8copy [QK_K];
|
int8x16_t q8copy [QK_K];
|
||||||
int32x16_t aux32;
|
int32x16_t aux32;
|
||||||
uint8x16_t q4copyvec [QK_K/32];
|
uint8x16_t q4copyvec [QK_K/32];
|
||||||
uint8x16_t aux8 [QK_K/16];
|
uint8x16_t aux8 [QK_K/16];
|
||||||
|
|
||||||
// Fill in our 8 bit vector from y[]. Required because there is no good way to align members of y[], and I haven't mastered unaligned assembly yet...
|
// Fill in our 8 bit vector from y[]. Required because there is no good way to align members of y[], and I haven't mastered unaligned assembly yet...
|
||||||
memcpy (q8copy, y[i].qs, QK_K);
|
memcpy (q8copy, y[i].qs, QK_K);
|
||||||
|
|
||||||
// Fill in our 4 bit vector from x[]. Required because there is no good way to align members of x[], and I haven't mastered unaligned assembly yet...
|
// Fill in our 4 bit vector from x[]. Required because there is no good way to align members of x[], and I haven't mastered unaligned assembly yet...
|
||||||
memcpy (q4copyvec, x[i].qs, QK_K/2);
|
memcpy (q4copyvec, x[i].qs, QK_K/2);
|
||||||
|
|
||||||
// combine our 4 and 1 bit vector sets into an 8 bit value.
|
// combine our 4 and 1 bit vector sets into an 8 bit value.
|
||||||
GGML_5bit_Unpack(q4copyvec, x[i].qh, aux8);
|
GGML_5bit_Unpack(q4copyvec, x[i].qh, aux8);
|
||||||
|
|
||||||
// extract scales and mins..
|
// extract scales and mins..
|
||||||
memcpy(utmp, x[i].scales, 12);
|
memcpy(utmp, x[i].scales, 12);
|
||||||
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
||||||
const uint32_t uaux = utmp[1] & kmask1;
|
const uint32_t uaux = utmp[1] & kmask1;
|
||||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||||
utmp[2] = uaux;
|
utmp[2] = uaux;
|
||||||
utmp[0] &= kmask1;
|
utmp[0] &= kmask1;
|
||||||
|
|
||||||
// FIXME: while comparing FMA output to the original output, the original had an error. hunt it down.
|
// FIXME: while comparing FMA output to the original output, the original had an error. hunt it down.
|
||||||
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16(q8copy, aux8, scales, &aux32);
|
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16(q8copy, aux8, scales, &aux32);
|
||||||
|
|
||||||
int sumi = 0;
|
int sumi = 0;
|
||||||
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
||||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
for (int l = 0; l < GGML_F32_EPR; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
|
for (int l = 0; l < GGML_F32_EPR; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
|
||||||
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
||||||
sumf -= dmin * sumi;
|
sumf -= dmin * sumi;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int l = 0; l < GGML_F32_EPR; ++l) sumf += ((float *)&sums)[l];
|
for (int l = 0; l < GGML_F32_EPR; ++l) sumf += ((float *)&sums)[l];
|
||||||
|
|
172
ggml-phi-knc.c
172
ggml-phi-knc.c
|
@ -23,11 +23,11 @@ inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target)
|
||||||
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
|
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our value.
|
||||||
"vmovnraps\t\t%%zmm8,\t%[RES]\n\t"
|
"vmovnraps\t\t%%zmm8,\t%[RES]\n\t"
|
||||||
: [RES] "+m" (*target)
|
: [RES] "+m" (*target)
|
||||||
: [Z] "m" (zero)
|
: [Z] "m" (zero)
|
||||||
: "zmm8");
|
: "zmm8");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,73 +37,73 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x
|
||||||
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
uint8_t zero[4] __attribute__((aligned(64))) = {0,0,0,0};
|
||||||
|
|
||||||
__asm__ __volatile__ (
|
__asm__ __volatile__ (
|
||||||
"mov\t%[ITER],%%r8\n\t" // how many register sized chunks are we responsible for
|
"mov\t%[ITER],%%r8\n\t" // how many register sized chunks are we responsible for
|
||||||
"mov\t%[VEC1],%%r10\n\t" // where do we start work in mvec1?
|
"mov\t%[VEC1],%%r10\n\t" // where do we start work in mvec1?
|
||||||
"mov\t%[VEC2],%%r12\n\t" // where do we start work in mvec2?
|
"mov\t%[VEC2],%%r12\n\t" // where do we start work in mvec2?
|
||||||
"cmp\t$1,%[CLR]\n\t" // should we clear the sum before we start?
|
"cmp\t$1,%[CLR]\n\t" // should we clear the sum before we start?
|
||||||
"jne\t4f\n\t"
|
"jne\t4f\n\t"
|
||||||
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm0\n\t" // if so, use an upscaling operator to do it.
|
"vbroadcastf32x4\t%[Z]%{uint8%},\t%%zmm0\n\t" // if so, use an upscaling operator to do it.
|
||||||
"vprefetchnta\t(%%r10)\n\t"
|
"vprefetchnta\t(%%r10)\n\t"
|
||||||
"vprefetchnta\t(%%r12)\n\t"
|
"vprefetchnta\t(%%r12)\n\t"
|
||||||
"vprefetch1\t128(%%r10)\n\t"
|
"vprefetch1\t128(%%r10)\n\t"
|
||||||
"vprefetch1\t128(%%r12)\n\t"
|
"vprefetch1\t128(%%r12)\n\t"
|
||||||
"vprefetch1\t256(%%r10)\n\t"
|
"vprefetch1\t256(%%r10)\n\t"
|
||||||
"vprefetch1\t256(%%r12)\n\t"
|
"vprefetch1\t256(%%r12)\n\t"
|
||||||
"vprefetch1\t384(%%r10)\n\t"
|
"vprefetch1\t384(%%r10)\n\t"
|
||||||
"vprefetch1\t384(%%r12)\n\t"
|
"vprefetch1\t384(%%r12)\n\t"
|
||||||
"vprefetch1\t512(%%r10)\n\t"
|
"vprefetch1\t512(%%r10)\n\t"
|
||||||
"vprefetch1\t512(%%r12)\n\t"
|
"vprefetch1\t512(%%r12)\n\t"
|
||||||
"jmp\t1f\n\t"
|
"jmp\t1f\n\t"
|
||||||
"4:\n\t"
|
"4:\n\t"
|
||||||
"vprefetch0\t(%[RES])\n\t"
|
"vprefetch0\t(%[RES])\n\t"
|
||||||
"vmovaps\t\t(%[RES]),\t%%zmm0\n\t" // otherwise, load our initial state from sum.
|
"vmovaps\t\t(%[RES]),\t%%zmm0\n\t" // otherwise, load our initial state from sum.
|
||||||
"vprefetchnta\t(%%r10)\n\t"
|
"vprefetchnta\t(%%r10)\n\t"
|
||||||
"vprefetchnta\t(%%r12)\n\t"
|
"vprefetchnta\t(%%r12)\n\t"
|
||||||
"1:\n\t"
|
"1:\n\t"
|
||||||
"cmp\t$3,\t%%r8\n\t" // Compare iterations to three.
|
"cmp\t$3,\t%%r8\n\t" // Compare iterations to three.
|
||||||
"jnae\t6f\n\t" // If there are not three iterations left, jump to label 6.
|
"jnae\t6f\n\t" // If there are not three iterations left, jump to label 6.
|
||||||
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
||||||
"sub\t$3,\t%%r8\n\t" // Decrement iterations
|
"sub\t$3,\t%%r8\n\t" // Decrement iterations
|
||||||
"vprefetchnta\t192(%%r10)\n\t" // prefetch the next float32x16_t block (192 bytes ahead)
|
"vprefetchnta\t192(%%r10)\n\t" // prefetch the next float32x16_t block (192 bytes ahead)
|
||||||
"vprefetchnta\t192(%%r12)\n\t"
|
"vprefetchnta\t192(%%r12)\n\t"
|
||||||
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
||||||
"vprefetch1\t320(%%r10)\n\t" // prefetch the block after the block after the next float32x16_t block (320 bytes ahead)
|
"vprefetch1\t320(%%r10)\n\t" // prefetch the block after the block after the next float32x16_t block (320 bytes ahead)
|
||||||
"vprefetch1\t320(%%r12)\n\t"
|
"vprefetch1\t320(%%r12)\n\t"
|
||||||
"vmovaps\t\t128(%%r10),\t%%zmm5\n\t" // Load two vectors.
|
"vmovaps\t\t128(%%r10),\t%%zmm5\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t128(%%r12),\t%%zmm6\n\t"
|
"vmovaps\t\t128(%%r12),\t%%zmm6\n\t"
|
||||||
"vprefetch1\t576(%%r10)\n\t"
|
"vprefetch1\t576(%%r10)\n\t"
|
||||||
"vprefetch1\t576(%%r12)\n\t"
|
"vprefetch1\t576(%%r12)\n\t"
|
||||||
"vprefetch1\t704(%%r10)\n\t"
|
"vprefetch1\t704(%%r10)\n\t"
|
||||||
"vprefetch1\t704(%%r12)\n\t"
|
"vprefetch1\t704(%%r12)\n\t"
|
||||||
"add\t$192,\t%%r10\n\t" // Move to the next float32x16_t block (192 bytes ahead)
|
"add\t$192,\t%%r10\n\t" // Move to the next float32x16_t block (192 bytes ahead)
|
||||||
"add\t$192,\t%%r12\n\t"
|
"add\t$192,\t%%r12\n\t"
|
||||||
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"jmp\t1b\n\t" // Jump back to the start of the loop
|
"jmp\t1b\n\t" // Jump back to the start of the loop
|
||||||
"6:\n\t" // we know we are near the tail. handle 2, 1, and 0 cases.
|
"6:\n\t" // we know we are near the tail. handle 2, 1, and 0 cases.
|
||||||
"cmp\t$0,\t%%r8\n\t" // Compare iterations to zero
|
"cmp\t$0,\t%%r8\n\t" // Compare iterations to zero
|
||||||
"je\t2f\n\t" // Jump to label 2 if zero (end of loop)
|
"je\t2f\n\t" // Jump to label 2 if zero (end of loop)
|
||||||
"cmp\t$1,\t%%r8\n\t" // Compare iterations to one
|
"cmp\t$1,\t%%r8\n\t" // Compare iterations to one
|
||||||
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
||||||
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"je\t2f\n\t" // Jump to label 2 if one (end of loop)
|
"je\t2f\n\t" // Jump to label 2 if one (end of loop)
|
||||||
// No compare. we must be two.
|
// No compare. we must be two.
|
||||||
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
||||||
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"2:\n\t" // Label for loop end
|
"2:\n\t" // Label for loop end
|
||||||
"vmovnraps\t\t%%zmm0,\t(%[RES])\n\t" // save our results.
|
"vmovnraps\t\t%%zmm0,\t(%[RES])\n\t" // save our results.
|
||||||
: [RES] "+r" (sumvec)
|
: [RES] "+r" (sumvec)
|
||||||
: [ITER] "r" (iterations),
|
: [ITER] "r" (iterations),
|
||||||
[VEC1] "r" (mvec1),
|
[VEC1] "r" (mvec1),
|
||||||
[VEC2] "r" (mvec2),
|
[VEC2] "r" (mvec2),
|
||||||
[CLR] "r" (clear),
|
[CLR] "r" (clear),
|
||||||
[Z] "m" (zero)
|
[Z] "m" (zero)
|
||||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "cc", "memory", "r8", "r10", "r12");
|
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "cc", "memory", "r8", "r10", "r12");
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: x and y inputs must be __attribute__((aligned(64)));
|
// NOTE: x and y inputs must be __attribute__((aligned(64)));
|
||||||
|
@ -119,24 +119,24 @@ void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restri
|
||||||
|
|
||||||
// add the leftovers, that could not be handled by the vector loop.
|
// add the leftovers, that could not be handled by the vector loop.
|
||||||
if ( n - np != 0 )
|
if ( n - np != 0 )
|
||||||
{
|
{
|
||||||
// our extended last part of x.
|
// our extended last part of x.
|
||||||
float32x16_t v1;
|
float32x16_t v1;
|
||||||
GGML_F32x16_VEC_ZERO(&v1);
|
GGML_F32x16_VEC_ZERO(&v1);
|
||||||
// our extended last part of y.
|
// our extended last part of y.
|
||||||
float32x16_t v2;
|
float32x16_t v2;
|
||||||
GGML_F32x16_VEC_ZERO(&v2);
|
GGML_F32x16_VEC_ZERO(&v2);
|
||||||
|
|
||||||
memcpy(&v1, &x[np], (n - np)*sizeof(float));
|
memcpy(&v1, &x[np], (n - np)*sizeof(float));
|
||||||
memcpy(&v2, &y[np], (n - np)*sizeof(float));
|
memcpy(&v2, &y[np], (n - np)*sizeof(float));
|
||||||
|
|
||||||
GGML_F32x16_VEC_FMA(&v1,
|
GGML_F32x16_VEC_FMA(&v1,
|
||||||
&v2,
|
&v2,
|
||||||
&sum, 1, 0);
|
&sum, 1, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce sum, and store it in s.
|
// reduce sum, and store it in s.
|
||||||
for (uint32_t i=0; i <GGML_F32_EPR; ++i)
|
for (uint32_t i=0; i <GGML_F32_EPR; ++i)
|
||||||
*s+=((float *)&sum)[i];
|
*s+=((float *)&sum)[i];
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue