This commit is contained in:
chooper1 2023-09-18 20:17:06 -07:00
parent 7954f8defd
commit 80f69694e5
3 changed files with 16 additions and 15 deletions

3
ggml.c
View file

@ -1792,7 +1792,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.to_float = NULL, .to_float = NULL,
.from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
.from_float_reference = NULL, .from_float_reference = NULL,
.vec_dot = ggml_vec_dot_q4_sq_fp16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_q4_sq_fp16,
.vec_dot_type = GGML_TYPE_F16, .vec_dot_type = GGML_TYPE_F16,
} }
#endif #endif
@ -12640,6 +12640,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K:
case GGML_TYPE_Q4_SQ:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:

26
sqllm.c
View file

@ -8,7 +8,7 @@
#include <stdlib.h> #include <stdlib.h>
void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, const void * restrict v, const ggml_fp16_t * restrict y) { void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, void * restrict v, ggml_fp16_t * restrict y) {
const int nb = n / 8; const int nb = n / 8;
@ -17,7 +17,7 @@ void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, const void * restr
// pointer initialization // pointer initialization
int32_t * baselut = v; int32_t * baselut = v;
int32_t * qweight = baselut + 8; // get start of row int32_t * qweight = baselut + 8; // get start of row
float * yvector = y; float * yvector = (void *) y;
// initialize sum // initialize sum
float16x8_t sumf1 = vdupq_n_f16(0); float16x8_t sumf1 = vdupq_n_f16(0);
@ -26,15 +26,15 @@ void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, const void * restr
float16x8_t sumf4 = vdupq_n_f16(0); float16x8_t sumf4 = vdupq_n_f16(0);
// initialize lookup table // initialize lookup table
uint8x16_t lut1 = vld1q_u8(baselut); uint8x16_t lut1 = vld1q_u8((void *) baselut);
uint8x16_t lut2 = vld1q_u8(baselut+4); uint8x16_t lut2 = vld1q_u8((void *) (baselut+4));
uint8x16_t lutl = vuzp1q_u8(lut1, lut2); uint8x16_t lutl = vuzp1q_u8(lut1, lut2);
uint8x16_t luth = vuzp2q_u8(lut1, lut2); uint8x16_t luth = vuzp2q_u8(lut1, lut2);
for (int i = 0; i < nb; i+=4) { for (int i = 0; i < nb; i+=4) {
// get packed vector // get packed vector
uint8x16_t m4b = vdupq_n_u8(0x0F); uint8x16_t m4b = vdupq_n_u8(0x0F);
uint8x16_t packed_vector = vld1q_u8(&qweight[i]); uint8x16_t packed_vector = vld1q_u8((void *) &qweight[i]);
// 4-bit -> 2 8-bit vectors // 4-bit -> 2 8-bit vectors
uint8x16_t packed_vector_lb = vandq_u8 (packed_vector, m4b); uint8x16_t packed_vector_lb = vandq_u8 (packed_vector, m4b);
@ -51,16 +51,16 @@ void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, const void * restr
uint8x16_t lookup_1h = vqtbl1q_u8 (luth, packed_vector_1); uint8x16_t lookup_1h = vqtbl1q_u8 (luth, packed_vector_1);
// interleave lookup values // interleave lookup values
float16x8_t lookup_0_z1 = vzip1q_u8(lookup_0l, lookup_0h); float16x8_t lookup_0_z1 = (float16x8_t) vzip1q_u8(lookup_0l, lookup_0h);
float16x8_t lookup_0_z2 = vzip2q_u8(lookup_0l, lookup_0h); float16x8_t lookup_0_z2 = (float16x8_t) vzip2q_u8(lookup_0l, lookup_0h);
float16x8_t lookup_1_z1 = vzip1q_u8(lookup_1l, lookup_1h); float16x8_t lookup_1_z1 = (float16x8_t) vzip1q_u8(lookup_1l, lookup_1h);
float16x8_t lookup_1_z2 = vzip2q_u8(lookup_1l, lookup_1h); float16x8_t lookup_1_z2 = (float16x8_t) vzip2q_u8(lookup_1l, lookup_1h);
//load int8 values //load int8 values
float16x8_t tmp1 = vld1q_f16(&yvector[4*i]); float16x8_t tmp1 = vld1q_f16(((void *) &yvector[4*i]));
float16x8_t tmp2 = vld1q_f16(&yvector[4*i+4]); float16x8_t tmp2 = vld1q_f16(((void *) &yvector[4*i+4]));
float16x8_t tmp3 = vld1q_f16(&yvector[4*i+8]); float16x8_t tmp3 = vld1q_f16(((void *) &yvector[4*i+8]));
float16x8_t tmp4 = vld1q_f16(&yvector[4*i+12]); float16x8_t tmp4 = vld1q_f16(((void *) &yvector[4*i+12]));
//fp16 mul //fp16 mul
sumf1 = vfmaq_f16(sumf1, lookup_0_z1, tmp1); sumf1 = vfmaq_f16(sumf1, lookup_0_z1, tmp1);

View file

@ -10,4 +10,4 @@
#ifdef __ARM_NEON #ifdef __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, const void * restrict v, const ggml_fp16_t * restrict y); void ggml_vec_dot_q4_sq_fp16(const int n, float * restrict s, void * restrict v, ggml_fp16_t * restrict y);