further optimizations. 0.99 tokens per second.
This commit is contained in:
parent
d966ac2ebe
commit
c3d438bce2
4 changed files with 379 additions and 131 deletions
2
Makefile
2
Makefile
|
@ -760,7 +760,7 @@ bench-phi-knc.s: bench-phi-knc.c
|
|||
ggml-phi-knc.s: ggml-phi-knc.c
|
||||
$(CC) $(CFLAGS) -S $< -o $(call GET_ASM_FILE, $<)
|
||||
|
||||
bench-phi-knc: bench-phi-knc.c ggml-phi-knc.o
|
||||
bench-phi-knc: bench-phi-knc.c ggml-phi-knc.o ggml-phi-knc-dot_q5_K_q8_K.o
|
||||
$(CC) $(CFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CC) $(CFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
|
|
175
bench-phi-knc.c
175
bench-phi-knc.c
|
@ -1,33 +1,52 @@
|
|||
/* bench-phi-knc.c: benchmarks and tests for the Xeon PHI Knights Corner optimizations. */
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <unistd.h> /*for CLOCK_REALTIME? */
|
||||
|
||||
/* For CLOCK_REALTIME? */
|
||||
#include <unistd.h>
|
||||
#include <time.h>
|
||||
|
||||
/* For memcpy */
|
||||
#include <string.h>
|
||||
|
||||
/* include the increasingly inacurately named header for our F32 dot product code. */
|
||||
#include "ggml-phi-knc.h"
|
||||
|
||||
#define MAXVEC 1024768
|
||||
#define RUNTOTAL 12
|
||||
#define RUNS
|
||||
/* include the header for our Q8K_Q5K dot product code. */
|
||||
#include "ggml-phi-knc-dot_q5_K_q8_K.h"
|
||||
|
||||
// largest Float32 vectors to get the dot product of.
|
||||
#define F32_MAXVEC 1024768
|
||||
// how many benchmarks we will run in total.
|
||||
#define F32_RUNCOUNT 12
|
||||
#define F32_ITEMS_PER_RUN {10, 16, 17, 32, 33, 48, 49, 64, 65, 80, 81, 1024768}
|
||||
|
||||
int main(void)
|
||||
{
|
||||
struct timespec start, middle, end;
|
||||
double vector_time;
|
||||
double scalar_time;
|
||||
float scalar = 0.0f;
|
||||
float vector = 0.0f;
|
||||
int vecRuns[RUNTOTAL] = {10, 16, 17, 32, 33, 48, 49, 64, 65, 80, 81, 1024768};
|
||||
int vecRuns[F32_RUNCOUNT] = F32_ITEMS_PER_RUN;
|
||||
|
||||
for (uint32_t runCount = 0; runCount < RUNTOTAL; ++runCount)
|
||||
// seed the random number generator.
|
||||
srand(time(NULL));
|
||||
|
||||
// Run benchmarks for our F32 dot product functions. Benchmark them against a naieve implementation.
|
||||
for (uint8_t runCount = 0; runCount < F32_RUNCOUNT; ++runCount)
|
||||
{
|
||||
struct timespec start, middle, end;
|
||||
double vector_time;
|
||||
double scalar_time;
|
||||
float scalar = 0.0f;
|
||||
float vector = 0.0f;
|
||||
|
||||
// Generate random input vector of [-1, 1] values.
|
||||
float vec1[MAXVEC] __attribute__((aligned(64)));
|
||||
float vec1[F32_MAXVEC] __attribute__((aligned(64)));
|
||||
for (int i = 0; i < vecRuns[runCount]; i++)
|
||||
vec1[i] = 2 * (0.5 - rand() / (float)RAND_MAX);
|
||||
|
||||
// Generate a second random input vector of [-1, 1] values.
|
||||
float vec2[MAXVEC] __attribute__((aligned(64)));
|
||||
float vec2[F32_MAXVEC] __attribute__((aligned(64)));
|
||||
for (int i = 0; i < vecRuns[runCount]; i++)
|
||||
vec2[i] = 2 * (0.5 - rand() / (float)RAND_MAX);
|
||||
|
||||
|
@ -60,5 +79,135 @@ int main(void)
|
|||
|
||||
fflush(stdout);
|
||||
|
||||
// Generate a random input vector of 256 4 bit values.
|
||||
uint8x16_t q4[8];
|
||||
uint8_t * q4ptr = (uint8_t *)q4;
|
||||
for (int i = 0; i < 128; i++)
|
||||
q4ptr[i] = rand() && 0xFF;
|
||||
|
||||
// Generate a random input vector of 256 1 bit values.
|
||||
uint8x16_t q1[2];
|
||||
uint8_t * q1ptr = (uint8_t *)q1;
|
||||
for (int i = 0; i < 32; i++)
|
||||
q1ptr[i] = rand() && 0xFF;
|
||||
|
||||
// Get our reference, unshifted result.
|
||||
uint8x16_t q5[16];
|
||||
GGML_5bit_Unpack_Unaligned(q4, (uint8_t *)q1, q5);
|
||||
|
||||
printf("successfully got a Q5.\n");
|
||||
|
||||
// Perform alignment tests, for GGML_5bit_Unpack_Unaligned.
|
||||
// Try to run GGML_5bit_Unpack_Unaligned with all possible misalignments, and get it to fail.
|
||||
for (uint8_t shiftCount = 1; shiftCount < 16; ++shiftCount)
|
||||
{
|
||||
uint8x16_t q5new[16];
|
||||
uint8x16_t q4Shifted[9];
|
||||
|
||||
// create an off-by-shiftCount copy of q4.
|
||||
q4ptr = ((uint8_t *)q4Shifted) + shiftCount;
|
||||
memcpy (q4ptr, q4, 128);
|
||||
|
||||
// call the unaligned form of this function:
|
||||
GGML_5bit_Unpack_Unaligned((uint8x16_t *)q4ptr, (uint8_t *)q1, q5new);
|
||||
|
||||
for (uint32_t byteCount = 0; byteCount < 256; ++byteCount)
|
||||
{
|
||||
if ( ((uint8_t *)q5new)[byteCount] != ((uint8_t *)q5)[byteCount] )
|
||||
{
|
||||
printf("whoops!\nshiftCount: %d\nbyteCount: %d\n", shiftCount, byteCount);
|
||||
exit (-1);
|
||||
}
|
||||
}
|
||||
|
||||
printf("Got a Q5 offset by %d\n", shiftCount);
|
||||
}
|
||||
|
||||
// Generate a random input vector of 256 8 bit values.
|
||||
int8x16_t q8[16];
|
||||
int8_t * q8ptr = (int8_t *)q8;
|
||||
for (int i = 0; i < 256; i++)
|
||||
q8ptr[i] = rand() && 0xFF;
|
||||
|
||||
// Generate eight random scales, one for each pair of sums.
|
||||
uint8_t scale[8];
|
||||
for (int i = 0; i < 8; i++)
|
||||
scale[i] = rand() && 0xFF;
|
||||
|
||||
// Generate a random X scale.
|
||||
float rndScaleX = 2 * (0.5 - rand() / (float)RAND_MAX);
|
||||
ggml_fp16_t scaleX = GGML_PHI_FP32_TO_FP16(rndScaleX);
|
||||
|
||||
// Display the random X scale. Verifies FP32_TO_FP16_TO_FP32 is working.
|
||||
printf("rndScaleX: %f\n", rndScaleX);
|
||||
printf("scaleX: %x\n", scaleX);
|
||||
printf("newScaleX: %f\n", GGML_PHI_FP16_TO_FP32(scaleX));
|
||||
|
||||
// Generate a random Y scale.
|
||||
float scaleY = 2 * (0.5 - rand() / (float)RAND_MAX);
|
||||
printf("scaleY: %f\n", scaleY);
|
||||
|
||||
// Create a place for our golden result.
|
||||
float32x16_t res;
|
||||
|
||||
// Clear res.
|
||||
GGML_F32x16_VEC_ZERO(&res);
|
||||
|
||||
// Generate an initial result, to compare to.
|
||||
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned (q8, q5, scale, scaleX, scaleY, &res);
|
||||
|
||||
// Generate a sum of the result.
|
||||
float sum = 0.0f;
|
||||
for (int l = 0; l < 16; ++l) sum += ((float *)&res)[l];
|
||||
|
||||
printf("Got a res: %f\n", sum);
|
||||
|
||||
// Perform alignment tests, for GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned.
|
||||
// try to run GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned with all possible mis-alignments, and get it to fail.
|
||||
for (uint8_t shiftCount = 1; shiftCount < 16; ++shiftCount)
|
||||
{
|
||||
float32x16_t resNew1;
|
||||
int8x16_t q8Shifted[17];
|
||||
|
||||
// Create an off-by-shiftCount copy of q8.
|
||||
q8ptr = ((int8_t *)q8Shifted)+shiftCount;
|
||||
memcpy (q8ptr, q8, 256);
|
||||
|
||||
// Clear resNew.
|
||||
GGML_F32x16_VEC_ZERO(&resNew1);
|
||||
|
||||
// Call the unaligned form of this function:
|
||||
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned ((int8x16_t *)q8ptr, q5, scale, scaleX, scaleY, &resNew1);
|
||||
|
||||
// check the result against our reference.
|
||||
for (uint32_t floatCount = 0; floatCount < 64; ++floatCount)
|
||||
{
|
||||
if ( ((int8_t *)&resNew1)[floatCount] != ((int8_t *)&res)[floatCount] )
|
||||
{
|
||||
printf("whoops!\nshiftCount: %d\nfloatCount: %d\n", shiftCount, floatCount);
|
||||
for (uint32_t row = 0; row < 16 ; ++row)
|
||||
{
|
||||
for (int col1 = 0; col1 < 4; ++col1)
|
||||
{
|
||||
printf("%2.2x\t", ((int8_t *)&resNew1)[(4*row)+col1]);
|
||||
}
|
||||
printf(" vs ");
|
||||
for (int col2 = 0; col2 < 4; ++col2)
|
||||
{
|
||||
printf("%2.2x\t", ((int8_t *)&res)[(4*row)+col2]);
|
||||
}
|
||||
printf ("\n");
|
||||
}
|
||||
exit (-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a sum of our new result.
|
||||
float sumf = 0.0f;
|
||||
for (int l = 0; l < 16; ++l) sumf += ((float *)&resNew1)[l];
|
||||
|
||||
printf("Got a res from a Q8 offset by %d: %f\n", ((int)q8ptr) & 0x3F, sumf);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -15,135 +15,227 @@
|
|||
// For block_q5_K and block_q8_K.
|
||||
#include "ggml-common.h"
|
||||
|
||||
// This SIMD unit can work with 32 float32s at once.
|
||||
#define GGML_F32_STEP 32
|
||||
// We can fit 16 of these float32s in a single vector register.
|
||||
// For our vector types.
|
||||
#include "ggml-phi-knc-dot_q5_K_q8_K.h"
|
||||
|
||||
// We can fit 16 float32s in a single vector register.
|
||||
#define GGML_F32_EPR 16
|
||||
|
||||
/* we force an alignment, because i haven't written unaligned forms of the assembly functions, yet.. */
|
||||
typedef float float32x16_t __attribute__((vector_size (64), aligned(64)));
|
||||
typedef int8_t int8x16_t __attribute__((vector_size (16), aligned(16)));
|
||||
typedef uint8_t uint8x16_t __attribute__((vector_size (16), aligned(16)));
|
||||
typedef int32_t int32x16_t __attribute__((vector_size (64), aligned(64)));
|
||||
|
||||
/* A forward declaration, to keep GCC happy. */
|
||||
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc);
|
||||
|
||||
/* clear a vector of 16 floats. */
|
||||
inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target)
|
||||
/* Clear a vector of 16 floats. */
|
||||
void GGML_F32x16_VEC_ZERO(float32x16_t *target)
|
||||
{
|
||||
uint8_t zero=0;
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"vbroadcastss\t%[Z]%{uint8%},\t%%zmm8\n\t" // use an upscaling operator to clear our register.
|
||||
"vmovaps\t\t%%zmm8,\t%[RES]\n\t"
|
||||
"vbroadcastss\t%[Z]%{uint8%},\t%%zmm0\n\t" // use an upscaling operator to clear our register.
|
||||
"vmovaps\t\t%%zmm0,\t%[RES]\n\t"
|
||||
: [RES] "+m" (*target)
|
||||
: [Z] "m" (zero)
|
||||
: "zmm8", "memory");
|
||||
: "zmm0", "memory");
|
||||
|
||||
}
|
||||
|
||||
// This function perform two multiplies of an I8x16 and an I8x16 vector into two I16x16 vectors. then does an FMA on the scaled result of multiplying the two I16x16 vectors, adding the result into an I32x16.
|
||||
/* convert a FP16 to a FP32. */
|
||||
float GGML_PHI_FP16_TO_FP32(ggml_fp16_t src)
|
||||
{
|
||||
// we only care aboun one result.
|
||||
uint32_t mask=0x0001;
|
||||
|
||||
// we declare this as an array, so it ends up in a different memory section.
|
||||
float f32[1] __attribute__((aligned(64)));
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"kmov\t%[M],\t%%k1\n\t"
|
||||
"vbroadcastss\t%[SRC]%{float16%},\t%%zmm1%{%%k1%}\n\t"
|
||||
"vmovaps\t\t%%zmm1,\t%[DST]%{%%k1%}\n\t"
|
||||
: [DST] "+m" (f32)
|
||||
: [SRC] "m" (src),
|
||||
[M] "r" (mask)
|
||||
: "zmm1", "memory", "k1");
|
||||
return f32[0];
|
||||
}
|
||||
|
||||
/* convert a FP32 to a FP16. */
|
||||
ggml_fp16_t GGML_PHI_FP32_TO_FP16(float src)
|
||||
{
|
||||
uint32_t mask=0x0001;
|
||||
|
||||
// we declare this as an array, so it ends up in a different memory section.
|
||||
ggml_fp16_t f16[1] __attribute__((aligned(64)));
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"kmov\t%[M],\t%%k1\n\t"
|
||||
"vbroadcastss\t%[SRC],\t%%zmm2%{%%k1%}\n\t"
|
||||
"vmovaps\t\t%%zmm2%{float16%},\t%[DST]%{%%k1%}\n\t"
|
||||
: [DST] "+m" (f16)
|
||||
: [SRC] "m" (src),
|
||||
[M] "r" (mask)
|
||||
: "zmm2", "memory", "k1");
|
||||
return f16[0];
|
||||
}
|
||||
|
||||
|
||||
// This function perform two multiplies of an I8x16 and an I8x16 vector into two I16x16 vectors. then does an FMA on the scaled result of multiplying the two I16x16 vectors, adding the result into an I32x16. When done, it multiplies this I32x16 by a float, returning a F32x16.
|
||||
// it loops 8 times. well, actually four, with an unroll.
|
||||
inline static void GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16 (int8x16_t *src11, uint8x16_t *src21, const uint8_t *scale, int32x16_t *res)
|
||||
void GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned (const int8x16_t *q8, uint8x16_t *q5, const uint8_t *scale, ggml_fp16_t scaleX, float scaleY, float32x16_t *res)
|
||||
{
|
||||
uint8_t zero = 0;
|
||||
uint64_t q8offset=((uint64_t) q8) & 0x3f;
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"vprefetche0\t(%[SRC11])\n\t"
|
||||
"vprefetche0\t(%[SRC21])\n\t"
|
||||
"vprefetche0\t(%[SCALE])\n\t"
|
||||
"mov\t$0,\t%%ecx\n\t"
|
||||
"mov\t%[SRC11],\t%%r12\n\t"
|
||||
"mov\t%[SRC21],\t%%r8\n\t"
|
||||
"vprefetchenta\t(%[RES])\n\t"
|
||||
"vprefetch0\t64(%[SCALE])\n\t"
|
||||
"vprefetch0\t(%[SRC8])\n\t"
|
||||
"vprefetch0\t64(%[SRC8])\n\t"
|
||||
"vprefetch0\t(%[SRC5])\n\t"
|
||||
"mov\t%[SRC8],\t%%r11\n\t" // use r11 to store the address for vloadunpackld.
|
||||
"mov\t%[SRC5],\t%%r8\n\t"
|
||||
"mov\t%[SCALE],\t%%r9\n\t"
|
||||
"vpbroadcastd\t%[Z]%{uint8%},\t%%zmm7\n\t" // empty our result.
|
||||
"mov\t$0,\t%%ecx\n\t"
|
||||
"mov\t%[SRC8],\t%%r15\n\t" // use r12-r15 to store the addresses for vloadunpackhd.
|
||||
"mov\t%[SRC8],\t%%r14\n\t"
|
||||
"mov\t%[SRC8],\t%%r13\n\t"
|
||||
"mov\t%[SRC8],\t%%r12\n\t"
|
||||
"mov\t%[OFFSET],\t%%r10\n\t"
|
||||
"cmp\t$32,%%r10\n\t" // Examine OFFSET, and decide which (if any) of the vloadunpackhd invocations needs to be increaned by 64.
|
||||
"jl\t20f\n\t"
|
||||
"cmp\t$48,%%r10\n\t"
|
||||
"jl\t21f\n\t"
|
||||
"add\t$64,%%r12\n\t" // greater than 48.
|
||||
"jmp\t18f\n\t"
|
||||
"21:\n\t"
|
||||
"add\t$64,%%r13\n\t" // between 48 and 32.
|
||||
"jmp\t18f\n\t"
|
||||
"20:\n\t" // less than 32...
|
||||
"cmp\t$16,%%r10\n\t"
|
||||
"jz\t18f\n\t" // zero
|
||||
"jl\t23f\n\t"
|
||||
"add\t$64,%%r14\n\t" // between 32 and 16...
|
||||
"jmp\t18f\n\t"
|
||||
"23:\n\t"
|
||||
"add\t$64,%%r15\n\t" // between 16 and zero..
|
||||
"18:\n\t"
|
||||
"vbroadcastss\t%[SCALEY],\t%%zmm3\n\t" // load the scale factors coresponding to the two input vectors.
|
||||
"vbroadcastss\t%[SCALEX]%{float16%},\t%%zmm4\n\t"
|
||||
"vmulps\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // prepare the factor we're going to multiply the result by..
|
||||
"vmovaps\t\t(%[RES]),\t%%zmm6\n\t" // load our inital state from sum..
|
||||
"vpbroadcastd\t%[Z]%{uint8%},\t%%zmm7\n\t" // empty our result.
|
||||
"1:\n\t"
|
||||
"inc\t%%ecx\n\t" // we are in our loop, increment our counter.
|
||||
"cmp\t$4,\t%%ecx\n\t" // see if this is our last run-through.
|
||||
"vmovdqa32\t\t(%%r12)%{sint8%},\t%%zmm0\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vmovdqa32\t\t(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm0,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
||||
"vpbroadcastd\t(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
||||
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vmovdqa32\t\t16(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vmovdqa32\t\t16(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
||||
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vmovdqa32\t\t32(%%r12)%{sint8%},\t%%zmm8\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vmovdqa32\t\t32(%%r8)%{uint8%},\t%%zmm1\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm8,\t%%zmm1,\t%%zmm2\n\t" // perform our 64 bit multiply, low side.
|
||||
"vpbroadcastd\t1(%%r9)%{uint8%},\t%%zmm6\n\t" // load the item we will be multiplying by.
|
||||
"vpmadd231d\t%%zmm2,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vmovdqa32\t\t48(%%r12)%{sint8%},\t%%zmm3\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vmovdqa32\t\t48(%%r8)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm3,\t%%zmm4,\t%%zmm5\n\t" // perform our 64 bit multiply, low side.
|
||||
"vpmadd231d\t%%zmm5,\t%%zmm6,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"je\t2f\n\t" // if this is the last time through our loop, jump to 2.
|
||||
"vprefetche0\t64(%%r12)\n\t" // otherwise, prepare for another run-through.
|
||||
"vprefetche0\t64(%%r8)\n\t"
|
||||
"vprefetche2\t128(%%r12)\n\t"
|
||||
"vprefetche2\t128(%%r8)\n\t"
|
||||
"add\t$64,\t%%r12\n\t"
|
||||
"add\t$64,\t%%r8\n\t"
|
||||
"add\t$2,\t%%r9\n\t"
|
||||
"jmp\t1b\n\t"
|
||||
"2:\n\t"
|
||||
"vmovdqa32\t\t%%zmm7,\t(%[RES])\n\t" // save the result.
|
||||
"inc\t%%ecx\n\t" // we are in our loop, increment our counter.
|
||||
"vloadunpackld\t\t(%%r11)%{sint8%},\t%%zmm8\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vloadunpackld\t\t16(%%r11)%{sint8%},\t%%zmm9\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vloadunpackld\t\t32(%%r11)%{sint8%},\t%%zmm10\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vloadunpackld\t\t48(%%r11)%{sint8%},\t%%zmm11\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"vprefetch1\t128(%%r11)\n\t" // prepare for a run-through.
|
||||
"add\t$64,\t%%r11\n\t"
|
||||
"vloadunpackhd\t\t(%%r12)%{sint8%},\t%%zmm8\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"add\t$64,\t%%r12\n\t"
|
||||
"vloadunpackhd\t\t16(%%r13)%{sint8%},\t%%zmm9\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"add\t$64,\t%%r13\n\t"
|
||||
"vloadunpackhd\t\t32(%%r14)%{sint8%},\t%%zmm10\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"add\t$64,\t%%r14\n\t"
|
||||
"vloadunpackhd\t\t48(%%r15)%{sint8%},\t%%zmm11\n\t" // load the item we will be multiplying from. upscale it from int8 to int32.
|
||||
"add\t$64,\t%%r15\n\t"
|
||||
"vmovdqa32\t\t(%%r8)%{uint8%},\t%%zmm12\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm8,\t%%zmm12,\t%%zmm13\n\t" // perform our 64 bit multiply, low side.
|
||||
"vmovdqa32\t\t16(%%r8)%{uint8%},\t%%zmm14\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm9,\t%%zmm14,\t%%zmm15\n\t" // perform our 64 bit multiply, low side.
|
||||
"vmovdqa32\t\t32(%%r8)%{uint8%},\t%%zmm0\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm10,\t%%zmm0,\t%%zmm1\n\t" // perform our 64 bit multiply, low side.
|
||||
"vmovdqa32\t\t48(%%r8)%{uint8%},\t%%zmm2\n\t" // load the item we will be multiplying with. upscale it from int8 to int32.
|
||||
"vpmulld\t%%zmm11,\t%%zmm2,\t%%zmm3\n\t" // perform our 64 bit multiply, low side.
|
||||
"vprefetch1\t64(%%r8)\n\t" // prepare for a run-through.
|
||||
"add\t$64,\t%%r8\n\t"
|
||||
"vpbroadcastd\t(%%r9)%{uint8%},\t%%zmm4\n\t" // load the item we will be multiplying by.
|
||||
"vpbroadcastd\t1(%%r9)%{uint8%},\t%%zmm8\n\t" // load the item we will be multiplying by.
|
||||
"vprefetch1\t2(%%r9)\n\t"
|
||||
"add\t$2,\t%%r9\n\t"
|
||||
"vprefetch0\t(%%r11)\n\t" // prepare for a run-through.
|
||||
"vprefetch0\t64(%%r11)\n\t" // prepare for a run-through.
|
||||
"vprefetch0\t(%%r8)\n\t" // prepare for a run-through.
|
||||
"vprefetch0\t(%%r9)\n\t" // prepare for a run-through.
|
||||
"cmp\t$4,\t%%ecx\n\t" // see if this is our last run-through.
|
||||
"vpmadd231d\t%%zmm13,\t%%zmm4,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vpmadd231d\t%%zmm15,\t%%zmm4,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vpmadd231d\t%%zmm1,\t%%zmm8,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"vpmadd231d\t%%zmm3,\t%%zmm8,\t%%zmm7\n\t" // perform our multiply-add.
|
||||
"jl\t1b\n\t"
|
||||
"vcvtfxpntdq2ps\t$0,%%zmm7,\t%%zmm9\n\t" // convert our ints to floats.
|
||||
"vfmadd231ps\t%%zmm5,\t%%zmm9,\t%%zmm6\n\t" // Perform a fused multiply add.
|
||||
"vmovaps\t\t%%zmm6,\t(%[RES])\n\t" // save the result.
|
||||
: [RES] "+r" (res)
|
||||
: [SRC11] "r" (src11),
|
||||
[SRC21] "r" (src21),
|
||||
[SCALE] "r" (scale),
|
||||
: [SRC8] "r" (q8),
|
||||
[OFFSET] "m" (q8offset),
|
||||
[SRC5] "r" (q5),
|
||||
[SCALE] "r" (scale),
|
||||
[SCALEX] "m" (scaleX),
|
||||
[SCALEY] "m" (scaleY),
|
||||
[Z] "m" (zero)
|
||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "r8", "r9", "r12", "memory");
|
||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "cc", "ecx", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "memory");
|
||||
}
|
||||
|
||||
// Unpack 256 unsigned 5 bit values into an 8 bit vector.
|
||||
inline static void GGML_5bit_Unpack (const uint8x16_t * q4, const uint8_t * q1, uint8x16_t * dst)
|
||||
// Handles q4 not being aligned correctly.
|
||||
// Requires dst to be aligned.
|
||||
inline static void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint8x16_t * dst)
|
||||
{
|
||||
uint8_t lowmask = 0x0F;
|
||||
uint32_t allmask=0xFFFFFFFF;
|
||||
uint8_t m=1;
|
||||
uint8_t bit5 = 0x10;
|
||||
|
||||
__asm__ __volatile__ (
|
||||
"vprefetche0\t(%[SRC1])\n\t" // Issue our memory requests first thing.
|
||||
"vprefetche0\t(%[SRC4])\n\t"
|
||||
"vprefetche1\t64(%[SRC4])\n\t"
|
||||
"mov\t%[SRC4],\t%%r12\n\t" // load the address of the head of our 4-bit list.
|
||||
"mov\t%[DST],\t%%r8\n\t" // load the address of the head of our destination list.
|
||||
"mov\t$0,%%ecx\n\t" // initialize our counter.
|
||||
"vmovdqa32\t(%[SRC1])%{uint8%},\t%%zmm6\n\t" // move 16 packed sets of single bits into the lower 8 bits of zmm6.
|
||||
"vmovdqa32\t16(%[SRC1])%{uint8%},\t%%zmm7\n\t" // move the next 16 packed sets of single bits into the lower 8 bits of zmm7.
|
||||
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm2\n\t " // load our mask.
|
||||
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm9\n\t" // load the bit we want to add (conditionally).
|
||||
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm8\n\t" // select which bit we want to test for.
|
||||
"vprefetch0\t(%[SRC1])\n\t" // Issue our memory requests first thing.
|
||||
"vprefetch0\t(%[SRC4])\n\t"
|
||||
"vprefetchenta\t(%[DST])\n\t"
|
||||
"mov\t%[SRC4],\t%%r9\n\t" // load the address of the head of our 4-bit list.
|
||||
"mov\t%[DST],\t%%r8\n\t" // load the address of the head of our destination list.
|
||||
"mov\t$0,%%ecx\n\t" // initialize our counter.
|
||||
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm0\n\t" // load our mask.
|
||||
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm1\n\t" // load the bit we want to add (conditionally).
|
||||
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm2\n\t" // select which bit we want to test for.
|
||||
"vmovdqa32\t(%[SRC1])%{uint8%},\t%%zmm3\n\t" // load 16 sets of 8 bit packed single bits.
|
||||
"vmovdqa32\t16(%[SRC1])%{uint8%},\t%%zmm4\n\t" // load the next 16 sets of 8 bit packed single bits.
|
||||
|
||||
"1:\n\t"
|
||||
"inc\t%%ecx\n\t" // we are in the loop. increment the counter.
|
||||
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
||||
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
||||
"vmovdqa32\t\t(%%r12)%{uint8%},\t%%zmm0\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vpandd\t%%zmm0,\t%%zmm2,\t%%zmm4\n\t" // apply a mask, storing the low four bits of vector zmm0 into zmm4.
|
||||
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||
"vmovdqa32\t\t16(%%r12)%{uint8%},\t%%zmm1\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vpandd\t%%zmm1,\t%%zmm2,\t%%zmm5\n\t" // apply a mask, storing the next low four bits of vector zmm1 into zmm5.
|
||||
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||
"inc\t%%ecx\n\t" // we are in the loop. increment the counter.
|
||||
|
||||
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // perform our test.
|
||||
"vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // perform our test.
|
||||
|
||||
"vloadunpackld\t\t(%%r9)%{uint8%},\t%%zmm5\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vloadunpackhd\t\t16(%%r9)%{uint8%},\t%%zmm5\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vpandd\t%%zmm0,\t%%zmm5,\t%%zmm6\n\t" // apply a mask, storing the low four bits of vector zmm5 into zmm6.
|
||||
"vpord\t%%zmm1,%%zmm6,%%zmm6%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm6%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||
|
||||
"vloadunpackld\t\t(%%r9)%{uint8%},\t%%zmm7\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vloadunpackhd\t\t16(%%r9)%{uint8%},\t%%zmm7\n\t" // load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||
"vprefetch1\t32(%%r9)\n\t" // pull the next set of 4 bit sequences into the L2 cache.
|
||||
"vpandd\t%%zmm0,\t%%zmm7,\t%%zmm8\n\t" // apply a mask, storing the next low four bits of vector zmm1 into zmm5.
|
||||
"vpaddd\t%%zmm1,%%zmm8,%%zmm8%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||
|
||||
"add\t$32,\t%%r8\n\t"
|
||||
"cmp\t$4,\t%%ecx\n\t"
|
||||
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
||||
"vptestmd\t%%zmm6,\t%%zmm8,\t%%k1\n\t" // perform our test.
|
||||
"vptestmd\t%%zmm7,\t%%zmm8,\t%%k2\n\t" // perform our test.
|
||||
"vpsrld\t$4,\t%%zmm0,\t%%zmm4\n\t" // load our even 4 bit sequence into zmm4.
|
||||
"vpaddd\t%%zmm4,%%zmm9,%%zmm4%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm4%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||
"vpsrld\t$4,\t%%zmm1,\t%%zmm5\n\t" // load our even 4 bit sequence into zmm5.
|
||||
"vpaddd\t%%zmm5,%%zmm9,%%zmm5%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm5%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||
|
||||
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // select which bit we want to test for.
|
||||
|
||||
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // perform our test.
|
||||
"vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // perform our test.
|
||||
"vpsrld\t$4,\t%%zmm5,\t%%zmm6\n\t" // load our even 4 bit sequence
|
||||
"vpsrld\t$4,\t%%zmm7,\t%%zmm8\n\t" // load our even 4 bit sequence
|
||||
"vpord\t%%zmm1,%%zmm6,%%zmm6%{%%k1%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vpord\t%%zmm1,%%zmm8,%%zmm8%{%%k2%}\n\t" // turn on bit 5 for all values that passed the prior test.
|
||||
"vmovdqa32\t\t%%zmm6%{uint8%},\t(%%r8)\n\t" // save our result.
|
||||
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // save our result.
|
||||
"vprefetchenta\t32(%%r8)\n\t"
|
||||
|
||||
"je\t2f\n\t"
|
||||
"vpslld\t$1,\t%%zmm8,\t%%zmm8\n\t" // select which bit we want to test for.
|
||||
"add\t$32,\t%%r12\n\t"
|
||||
|
||||
"vprefetch0\t32(%%r9)\n\t"
|
||||
"vprefetch1\t96(%%r9)\n\t"
|
||||
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // select which bit we want to test for.
|
||||
"add\t$32,\t%%r9\n\t"
|
||||
"add\t$32,\t%%r8\n\t"
|
||||
"jmp\t1b\n\t"
|
||||
"2:"
|
||||
|
@ -152,9 +244,8 @@ inline static void GGML_5bit_Unpack (const uint8x16_t * q4, const uint8_t * q1,
|
|||
[SRC1] "r" (q1),
|
||||
[MASK] "m" (lowmask),
|
||||
[M] "m" (m),
|
||||
[ALL] "m" (allmask),
|
||||
[BIT5] "m" (bit5)
|
||||
: "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "cc", "ecx", "k1", "k2", "r12", "r8", "memory"
|
||||
: "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "k1", "k2", "r12", "r8", "memory"
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -185,19 +276,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
int8x16_t q8copy [QK_K];
|
||||
int32x16_t aux32;
|
||||
uint8x16_t q4copyvec [QK_K/32];
|
||||
uint8x16_t aux8 [QK_K/16];
|
||||
|
||||
// Fill in our 8 bit vector from y[]. required, because there is no good way to align members of y[], And I haven't mastered unaligned assembly yet...
|
||||
memcpy (q8copy, y[i].qs, QK_K);
|
||||
uint8x16_t q5 [QK_K/16];
|
||||
|
||||
// Fill in our 4 bit vector from x[]. required, because there is no good way to align members of x[], And I haven't mastered unaligned assembly yet...
|
||||
memcpy (q4copyvec, x[i].qs, QK_K/2);
|
||||
|
||||
// combine our 4 and 1 bit vector sets into an 8 bit value.
|
||||
GGML_5bit_Unpack(q4copyvec, x[i].qh, aux8);
|
||||
// combine our 4 and 1 bit vector sets into a 5 bit vector (in 8 bits).
|
||||
GGML_5bit_Unpack((const uint8x16_t *)x[i].qs, x[i].qh, q5);
|
||||
|
||||
// extract scales and mins..
|
||||
memcpy(utmp, x[i].scales, 12);
|
||||
|
@ -207,14 +290,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
utmp[2] = uaux;
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
// FIXME: while comparing FMA output to the original output, the original had an error. hunt it down.
|
||||
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16(q8copy, aux8, scales, &aux32);
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
||||
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
for (int l = 0; l < GGML_F32_EPR; ++l) ((float *)&sums)[l] += d * ((int32_t *)&aux32)[l];
|
||||
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
||||
|
||||
// FIXME: while comparing FMA output to the original output, the original had an error. hunt it down.
|
||||
GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned((const int8x16_t *)y[i].qs, q5, scales, x[i].d, y[i].d, &sums);
|
||||
|
||||
const float dmin = GGML_PHI_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
||||
sumf -= dmin * sumi;
|
||||
}
|
||||
|
||||
|
|
|
@ -8,8 +8,24 @@ extern "C"
|
|||
{
|
||||
#endif
|
||||
|
||||
/* A forward declaration, to keep GCC happy. */
|
||||
/* A forward declaration, to keep GCC happy. */
|
||||
void ggml_vec_dot_q5_K_q8_K(int n, float *restrict s, size_t bs, const void *restrict vx, size_t bx, const void *restrict vy, size_t by, int nrc);
|
||||
// Force an alignment onto these vectors.
|
||||
typedef float float32x16_t __attribute__((vector_size (64), aligned(64)));
|
||||
typedef int8_t int8x16_t __attribute__((vector_size (16), aligned(16)));
|
||||
typedef uint8_t uint8x16_t __attribute__((vector_size (16), aligned(16)));
|
||||
typedef int32_t int32x16_t __attribute__((vector_size (64), aligned(64)));
|
||||
|
||||
// Zero out a vector of Floats
|
||||
void GGML_F32x16_VEC_ZERO(float32x16_t *target);
|
||||
// Convert an FP16 value to FP32(Float).
|
||||
float GGML_PHI_FP16_TO_FP32(ggml_fp16_t src);
|
||||
// Convert an FP32 value to FP16.
|
||||
ggml_fp16_t GGML_PHI_FP32_TO_FP16(float src);
|
||||
// Create a 5 bit int vector from a 4 bit vector and a 1 bit vector, both in packed forms.
|
||||
void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint8x16_t * dst);
|
||||
// Multiply a Q5 and Q8 vector against each other, with some scaling.
|
||||
void GGML_8X_2xI8x16_2xI8x16_MUL_2xI16x16_S_FMA_I32x16_Unaligned (const int8x16_t *q8, uint8x16_t *q5, const uint8_t *scale, ggml_fp16_t scaleX, float scaleY, float32x16_t *res);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue