ggml : add Q8_0 quantization for intermediate results

This commit is contained in:
Georgi Gerganov 2023-04-13 23:03:27 +03:00
parent aa485cee33
commit 3b894ec657
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 271 additions and 16 deletions

284
ggml.c
View file

@ -584,12 +584,19 @@ static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block si
// blocks of QK elements // blocks of QK elements
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
typedef struct { typedef struct {
float d; float d; // delta
float m; float m; // min
uint8_t qs[QK / 2]; // nibbles / quants uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_1; } block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
typedef struct {
float d; // delta
uint8_t qs[QK]; // nibbles / quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
@ -1042,6 +1049,76 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
#endif #endif
} }
// reference implementation for deterministic creation of model files
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int l = 0; l < QK; l++) {
const float v = x[i*QK + l];
amax = MAX(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
for (int l = 0; l < QK; ++l) {
const float v = x[i*QK + l]*id;
const uint8_t vi = (int8_t)roundf(v) + 128;
y[i].qs[l] = vi;
}
}
}
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0);
const int nb = k / QK;
block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
const int32x4_t vi = vcvtq_s32_f32(vf);
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
}
}
#else
// scalar
quantize_row_q8_0_reference(x, y, k);
#endif
}
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
assert(k % QK == 0); assert(k % QK == 0);
const int nb = k / QK; const int nb = k / QK;
@ -2344,12 +2421,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
sum00 += x0->m*y0->m; sum00 += x0->m*y0->m;
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
sum00 += x1->m*y1->m; sum00 += x1->m*y1->m;
sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h)); sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h)); sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t // dot product into int32x4_t
@ -2417,6 +2494,129 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
*s = sumf; *s = sumf;
} }
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK;
assert(n % QK == 0);
assert(nb % 2 == 0);
const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON)
float sum0 = 0.0f;
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t u128b = vdupq_n_u8(128);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
const uint8x16_t v1_0l = vld1q_u8(y0->qs);
const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
const uint8x16_t v1_1l = vld1q_u8(y1->qs);
const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
// interleave
const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
#endif
}
sumf = sum0 + sum1;
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
int sumi = 0;
for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j];
const int i0 = (int8_t) (v0 & 0xf) - 8;
const int i1 = (int8_t) (v0 >> 4) - 8;
const int i2 = (int) p1[2*j + 0] - 128;
const int i3 = (int) p1[2*j + 1] - 128;
/*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
sumi += i0*i2 + i1*i3;
}
sumf += d0*d1*sumi;
}
#endif
*s = sumf;
}
// compute GGML_VEC_DOT_UNROLL dot products at once // compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes // xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@ -2663,22 +2863,24 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F16] = 1, [GGML_TYPE_F16] = 1,
[GGML_TYPE_Q4_0] = QK, [GGML_TYPE_Q4_0] = QK,
[GGML_TYPE_Q4_1] = QK, [GGML_TYPE_Q4_1] = QK,
[GGML_TYPE_Q8_0] = QK,
[GGML_TYPE_I8] = 1, [GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1, [GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1, [GGML_TYPE_I32] = 1,
}; };
static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated"); static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float), [GGML_TYPE_F32] = sizeof(float),
[GGML_TYPE_F16] = sizeof(ggml_fp16_t), [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
[GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
[GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
[GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t), [GGML_TYPE_I32] = sizeof(int32_t),
}; };
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated"); static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@ -3371,6 +3573,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
assert(tensor->nb[0] == sizeof(int8_t)); assert(tensor->nb[0] == sizeof(int8_t));
@ -3431,6 +3637,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
assert(tensor->nb[0] == sizeof(int8_t)); assert(tensor->nb[0] == sizeof(int8_t));
@ -3485,6 +3695,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3529,6 +3743,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3571,6 +3789,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3615,6 +3837,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
{ {
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -5437,6 +5663,7 @@ static void ggml_compute_forward_dup(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5518,6 +5745,7 @@ static void ggml_compute_forward_add(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5570,6 +5798,7 @@ static void ggml_compute_forward_sub(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5622,6 +5851,7 @@ static void ggml_compute_forward_mul(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5674,6 +5904,7 @@ static void ggml_compute_forward_div(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5722,6 +5953,7 @@ static void ggml_compute_forward_sqr(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5770,6 +6002,7 @@ static void ggml_compute_forward_sqrt(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5828,6 +6061,7 @@ static void ggml_compute_forward_sum(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5905,6 +6139,7 @@ static void ggml_compute_forward_mean(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -5969,6 +6204,7 @@ static void ggml_compute_forward_repeat(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6017,6 +6253,7 @@ static void ggml_compute_forward_abs(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6065,6 +6302,7 @@ static void ggml_compute_forward_sgn(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6113,6 +6351,7 @@ static void ggml_compute_forward_neg(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6161,6 +6400,7 @@ static void ggml_compute_forward_step(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6209,6 +6449,7 @@ static void ggml_compute_forward_relu(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6274,6 +6515,7 @@ static void ggml_compute_forward_gelu(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6341,6 +6583,7 @@ static void ggml_compute_forward_silu(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6427,6 +6670,7 @@ static void ggml_compute_forward_norm(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6507,6 +6751,7 @@ static void ggml_compute_forward_rms_norm(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -6906,9 +7151,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = { [GGML_TYPE_Q4_0] = {
.dequantize_row_q = dequantize_row_q4_0, .dequantize_row_q = dequantize_row_q4_0,
.quantize_row_q = quantize_row_q4_0, .quantize_row_q = quantize_row_q8_0,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
.vec_dot_q = ggml_vec_dot_q4_0, .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
}, },
[GGML_TYPE_Q4_1] = { [GGML_TYPE_Q4_1] = {
.dequantize_row_q = dequantize_row_q4_1, .dequantize_row_q = dequantize_row_q4_1,
@ -6916,6 +7161,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
.vec_dot_q = ggml_vec_dot_q4_1, .vec_dot_q = ggml_vec_dot_q4_1,
}, },
// TODO: GGML_TYPE_Q8_0
}; };
// For internal test use // For internal test use
@ -7041,7 +7287,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
if (params->type == GGML_TASK_INIT) { if (params->type == GGML_TASK_INIT) {
char * wdata = params->wdata; char * wdata = params->wdata;
const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
@ -7072,7 +7318,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata; void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices // src0 indices
@ -7120,6 +7366,7 @@ static void ggml_compute_forward_mul_mat(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
{ {
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break; } break;
@ -7218,6 +7465,7 @@ static void ggml_compute_forward_scale(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -7383,6 +7631,7 @@ static void ggml_compute_forward_get_rows(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
{ {
ggml_compute_forward_get_rows_q(params, src0, src1, dst); ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break; } break;
@ -7472,6 +7721,7 @@ static void ggml_compute_forward_diag_mask_inf(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -7566,6 +7816,7 @@ static void ggml_compute_forward_soft_max(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -7749,6 +8000,7 @@ static void ggml_compute_forward_rope(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -8017,6 +8269,7 @@ static void ggml_compute_forward_conv_1d_1s(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -8285,6 +8538,7 @@ static void ggml_compute_forward_conv_1d_2s(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -8770,6 +9024,7 @@ static void ggml_compute_forward_flash_attn(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -8981,6 +9236,7 @@ static void ggml_compute_forward_flash_ff(
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -9913,9 +10169,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
} else } else
#endif #endif
{ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
}
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }

1
ggml.h
View file

@ -204,6 +204,7 @@ enum ggml_type {
GGML_TYPE_F16 = 1, GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_1 = 3,
GGML_TYPE_Q8_0 = 4,
GGML_TYPE_I8, GGML_TYPE_I8,
GGML_TYPE_I16, GGML_TYPE_I16,
GGML_TYPE_I32, GGML_TYPE_I32,