Support all LLaMA models + change Q4_0 quantization storage

This commit is contained in:
Georgi Gerganov 2023-03-11 10:47:09 +02:00
parent 5f2f970d51
commit 007a8f6f45
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 404 additions and 205 deletions

97
ggml.c
View file

@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
float * restrict pd = (float *) (y);
uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t * restrict pd = (uint8_t *) (y + 0*bs);
uint8_t * restrict pb = (uint8_t *) (y + 0*bs + sizeof(float));
uint8_t pp[QK/2];
@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
memcpy(pb + i*16, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#else
#error "not implemented for QK"
@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < 8; l++) {
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}
memcpy(pb + i*16, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#else
#error "not implemented for QK"
@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*QK + l + 0]*id;
@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4);
}
memcpy(pb + i*QK/2, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#endif
}
@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
// scalar
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const float d = *(const float *) (pd + i*bs);
const uint8_t * restrict pp = pb + i*QK/2;
const uint8_t * restrict pp = pb + i*bs;
for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2];
@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
y[i*QK + l + 0] = v0;
y[i*QK + l + 1] = v1;
@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
assert(n % QK == 0);
assert(nb % 2 == 0);
const float * restrict pd0 = (const float *) x;
const float * restrict pd1 = (const float *) y;
const size_t bs = sizeof(float) + QK/2;
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
const uint8_t * restrict pd0 = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pd1 = (const uint8_t *) (y + 0*bs);
const uint8_t * restrict pb0 = (const uint8_t *) (x + 0*bs + sizeof(float));
const uint8_t * restrict pb1 = (const uint8_t *) (y + 0*bs + sizeof(float));
float sumf = 0.0;
@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d1_0 = pd1[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_1 = pd1[i + 1];
const float d0_0 = *(const float *) (pd0 + i*bs);
const float d1_0 = *(const float *) (pd1 + i*bs);
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
const uint8x16_t v0_1 = vld1q_u8(p0 + bs);
const uint8x16_t v1_1 = vld1q_u8(p1 + bs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_0 = pd1[i + 0];
const float d1_1 = pd1[i + 1];
const float d0_0 = *(const float *) (pd0 + i*bs);
const float d1_0 = *(const float *) (pd1 + i*bs);
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8);
const v128_t v0_0 = wasm_v128_load(p0);
const v128_t v0_1 = wasm_v128_load(p0 + 16);
const v128_t v0_1 = wasm_v128_load(p0 + bs);
const v128_t v1_0 = wasm_v128_load(p1);
const v128_t v1_1 = wasm_v128_load(p1 + 16);
const v128_t v1_1 = wasm_v128_load(p1 + bs);
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const float d0 = *(const float *) (pd0 + i*bs);
const float d1 = *(const float *) (pd1 + i*bs);
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j];
@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
assert(n % QK == 0);
const int nb = n / QK;
const size_t bs = sizeof(float) + QK/2;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
#if __ARM_NEON
#if QK == 32
for (int i = 0; i < nb; ++i) {
const float d0 = pd[i]*v;
const float d0 = v*(*(const float *) (pd + i*bs));
const uint8_t * restrict pp = pb + i*16;
const uint8_t * restrict pp = pb + i*bs;
const uint8x8_t m4b = vdup_n_u8(0xf);
const int8x8_t s8b = vdup_n_s8(0x8);
@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const float d = *(const float *) (pd + i*bs);
const uint8_t * restrict pp = pb + i*QK/2;
const uint8_t * restrict pp = pb + i*bs;
for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2];