k_quants: WIP super-blocks with 64 weights

Iwan Kawrakow 2023-06-21 12:43:44 +03:00
parent 447ccbe8c3
commit d2f12ac354
4 changed files with 270 additions and 24 deletions
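For context: the k-quants pack weights into super-blocks of QK_K values, and this commit makes QK_K selectable at build time. With the new LLAMA_QKK_64 option (which defines GGML_QKK_64), QK_K drops from 256 to 64 and K_SCALE_SIZE from 12 to 4, and each quantize/dequantize routine gains a QK_K == 64 branch. A minimal sketch of the compile-time switch and the row blocking it implies, mirroring the k_quants.h hunk below (num_super_blocks is a hypothetical helper, not part of the diff):

    // Mirrors the QK_K selection this commit adds to k_quants.h.
    #ifdef GGML_QKK_64
    #define QK_K 64            // super-blocks of 64 weights
    #define K_SCALE_SIZE 4
    #else
    #define QK_K 256           // default: super-blocks of 256 weights
    #define K_SCALE_SIZE 12
    #endif

    // Every routine in k_quants.c walks a row of k weights super-block by
    // super-block, so k must be a multiple of QK_K (the code asserts this).
    static int num_super_blocks(int k) {
        return k / QK_K;
    }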

CMakeLists.txt

@@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -292,6 +293,9 @@ endif()
if (LLAMA_K_QUANTS)
set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
add_compile_definitions(GGML_USE_K_QUANTS)
if (LLAMA_QKK_64)
add_compile_definitions(GGML_QKK_64)
endif()
endif()
if (LLAMA_CLBLAST)

Makefile

@@ -131,6 +131,10 @@ ifndef LLAMA_NO_K_QUANTS
CFLAGS += -DGGML_USE_K_QUANTS
CXXFLAGS += -DGGML_USE_K_QUANTS
OBJS += k_quants.o
ifdef LLAMA_QKK_64
CFLAGS += -DGGML_QKK_64
CXXFLAGS += -DGGML_QKK_64
endif
endif
ifndef LLAMA_NO_ACCELERATE

k_quants.c

@@ -330,11 +330,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif
x += QK_K;
@@ -352,6 +358,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
const uint8_t * q = x[i].qs;
#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +377,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
}
q += 32;
}
#else
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
for (int l = 0; l < 16; ++l) {
y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
}
y += QK_K;
#endif
}
}
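The QK_K == 64 branches above pack a super-block's 64 two-bit quants into 16 bytes (four groups of 16 weights, one group per bit pair) and apply one 4-bit scale/min pair per group when dequantizing. A standalone round-trip sketch of just the bit packing, with hypothetical helper names:

    #include <assert.h>
    #include <stdint.h>

    // Pack 64 two-bit values (0..3) into 16 bytes, as in the QK_K == 64 branch.
    static void pack_q2_64(const uint8_t L[64], uint8_t qs[16]) {
        for (int l = 0; l < 16; ++l) {
            qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
        }
    }

    // Recover the two-bit value of weight j (0..63) from the packed bytes.
    static uint8_t unpack_q2_64(const uint8_t qs[16], int j) {
        return (qs[j % 16] >> (2 * (j / 16))) & 3;
    }

    int main(void) {
        uint8_t L[64], qs[16];
        for (int j = 0; j < 64; ++j) L[j] = (uint8_t)(j % 4);
        pack_q2_64(L, qs);
        for (int j = 0; j < 64; ++j) assert(unpack_q2_64(qs, j) == L[j]);
        return 0;
    }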
@@ -412,6 +431,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
}
}
#if QK_K == 256
memset(y[i].scales, 0, 12);
if (max_scale) {
float iscale = -32.f/max_scale;
@@ -445,9 +465,36 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
L[16*j + ii] = l + 4;
}
}
#else
if (max_scale) {
float iscale = -128.f/max_scale;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(iscale*scales[j]);
l = MAX(-128, MIN(127, l));
y[i].scales[j] = l;
}
y[i].d = ggml_fp32_to_fp16(1/iscale);
} else {
for (int j = 0; j < QK_K/16; ++j) {
y[i].scales[j] = 0;
}
y[i].d = ggml_fp32_to_fp16(0.f);
}
for (int j = 0; j < QK_K/16; ++j) {
float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-4, MIN(3, l));
L[16*j + ii] = l + 4;
}
}
#endif
memset(y[i].hmask, 0, QK_K/8);
// We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
// We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
int m = 0;
uint8_t hm = 1;
for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +506,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
m = 0; hm <<= 1;
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif
x += QK_K;
}
}
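With QK_K == 64 the hmask loop above wraps after QK_K/8 = 8 bytes, so the high (third) bit of quant j lands in hmask[j % 8] at bit position j / 8, which is the layout the updated comment describes and that the QK_K == 64 dequantizer below reads back through the masks 0x01..0x80. A small sketch of that index mapping, with hypothetical helper names:

    #include <stdint.h>

    // QK_K == 64: the high bit of quant j is stored in hmask[j % 8],
    // at bit j / 8 (bit 0 for quants 0..7, bit 1 for quants 8..15, ...).
    static void set_q3_high_bit_64(uint8_t hmask[8], int j) {
        hmask[j % 8] |= (uint8_t)(1u << (j / 8));
    }

    static int get_q3_high_bit_64(const uint8_t hmask[8], int j) {
        return (hmask[j % 8] >> (j / 8)) & 1;
    }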
#if QK_K == 256
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
assert(QK_K == 256);
const int nb = k / QK_K;
const uint32_t kmask1 = 0x03030303;
@@ -519,6 +572,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
}
}
#else
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
assert(QK_K == 64);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const float d_all = ggml_fp16_to_fp32(x[i].d);
const uint8_t * restrict q = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const float d1 = d_all * x[i].scales[0];
const float d2 = d_all * x[i].scales[1];
const float d3 = d_all * x[i].scales[2];
const float d4 = d_all * x[i].scales[3];
for (int l=0; l<8; ++l) {
uint8_t h = hm[l];
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
}
y += QK_K;
}
}
#endif
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
quantize_row_q3_K_reference(x, vy, k);
@@ -544,11 +630,14 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
const int nb = k / QK_K;
uint8_t L[QK_K];
#if QK_K == 256
float mins[QK_K/32];
float scales[QK_K/32];
#endif
for (int i = 0; i < nb; i++) {
#if QK_K == 256
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +683,28 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
L[32*j + ii] = l;
}
}
#else
for (int j = 0; j < QK_K/32; ++j) {
float min;
float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
}
for (int j = 0; j < QK_K/32; ++j) {
const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
if (!d) continue;
const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(15, l));
L[32*j + ii] = l;
}
}
#endif
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
}
x += QK_K;
@@ -610,11 +718,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
#if QK_K == 256
const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);
const uint8_t * q = x[i].qs;
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +736,15 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
#else
float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
for (int l = 0; l < 32; ++l) {
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
y[l+32] = d2 * (q[l] >> 4) - m2;
}
y += QK_K;
#endif
}
}
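In the QK_K == 64 layout above, a q4_K super-block is simply two 32-weight blocks: d[0]/d[1] hold the first block's scale and min, d[2]/d[3] the second's, and each byte of qs carries one weight from each block in its low and high nibble. A standalone sketch of that dequantization, using plain floats where the real code converts from ggml_fp16_t:

    #include <stdint.h>

    // Mirrors the QK_K == 64 branch of dequantize_row_q4_K above.
    static void dequant_q4_64(const float d[4], const uint8_t qs[32], float y[64]) {
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d[0] * (qs[l] & 0xF) - d[1];  // low nibble:  weights  0..31
            y[l + 32] = d[2] * (qs[l] >> 4)  - d[3];  // high nibble: weights 32..63
        }
    }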
@@ -654,11 +773,15 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
const int nb = k / QK_K;
uint8_t L[QK_K];
#if QK_K == 256
float mins[QK_K/32];
float scales[QK_K/32];
#endif
for (int i = 0; i < nb; i++) {
#if QK_K == 256
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +848,42 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
m1 <<= 2; m2 <<= 2;
ql += 32;
}
#else
for (int j = 0; j < QK_K/32; ++j) {
float min;
float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5);
y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
}
for (int j = 0; j < QK_K/32; ++j) {
const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
if (!d) continue;
const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(31, l));
L[32*j + ii] = l;
}
}
uint8_t * restrict qh = y[i].qh;
uint8_t * restrict ql = y[i].qs;
memset(qh, 0, QK_K/8);
for (int j = 0; j < 32; ++j) {
int jm = j%8;
int is = j/8;
int l1 = L[j];
if (l1 > 15) {
l1 -= 16; qh[jm] |= (1 << is);
}
int l2 = L[j + 32];
if (l2 > 15) {
l2 -= 16; qh[jm] |= (1 << (4 + is));
}
ql[j] = l1 | (l2 << 4);
}
#endif
x += QK_K;
@@ -737,12 +896,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
for (int i = 0; i < nb; i++) {
const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);
const uint8_t * ql = x[i].qs;
const uint8_t * qh = x[i].qh;
#if QK_K == 256
const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
@@ -756,6 +917,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
#else
float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
for (int l = 0; l < 8; ++l) {
y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1;
y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1;
y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1;
y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1;
y[l+32] = d2 * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - m2;
y[l+40] = d2 * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - m2;
y[l+48] = d2 * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - m2;
y[l+56] = d2 * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - m2;
}
y += QK_K;
#endif
}
}
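The QK_K == 64 q5_K layout keeps the low four bits of all 64 quants in ql[32] (two per byte) and gathers the fifth bits in qh[8]: weight j contributes bit j/8 of qh[j % 8] when j < 32, and bit 4 + (j - 32)/8 otherwise, which is exactly what the masks 0x01..0x80 in the dequantizer above pick apart. A standalone round-trip sketch with hypothetical helper names:

    #include <stdint.h>
    #include <string.h>

    // Pack 64 five-bit values (0..31) into ql[32] (low 4 bits) and qh[8] (5th bit),
    // following the QK_K == 64 packing in quantize_row_q5_K_reference above.
    static void pack_q5_64(const uint8_t L[64], uint8_t ql[32], uint8_t qh[8]) {
        memset(qh, 0, 8);
        for (int j = 0; j < 32; ++j) {
            int l1 = L[j], l2 = L[j + 32];
            if (l1 > 15) { l1 -= 16; qh[j % 8] |= (uint8_t)(1u << (j / 8)); }
            if (l2 > 15) { l2 -= 16; qh[j % 8] |= (uint8_t)(1u << (4 + j / 8)); }
            ql[j] = (uint8_t)(l1 | (l2 << 4));
        }
    }

    // Recover the five-bit value of weight j (0..63).
    static uint8_t unpack_q5_64(const uint8_t ql[32], const uint8_t qh[8], int j) {
        const uint8_t lo  = (j < 32) ? (ql[j] & 0xF) : (ql[j - 32] >> 4);
        const int     bit = (j < 32) ? (j / 8) : (4 + (j - 32) / 8);
        return (uint8_t)(lo | (((qh[j % 8] >> bit) & 1) << 4));
    }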
@@ -823,6 +999,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1013,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
ql += 64;
qh += 32;
}
#else
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[l + 0] & 0xF;
const uint8_t q2 = L[l + 32] & 0xF;
ql[l] = q1 | (q2 << 4);
}
for (int l = 0; l < 16; ++l) {
qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
}
#endif
x += QK_K;
@@ -854,6 +1041,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
@@ -871,6 +1059,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
qh += 32;
sc += 8;
}
#else
for (int l = 0; l < 16; ++l) {
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l+ 0] = d * sc[0] * q1;
y[l+16] = d * sc[1] * q2;
y[l+32] = d * sc[2] * q3;
y[l+48] = d * sc[3] * q4;
}
y += 64;
#endif
}
}
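For q6_K with QK_K == 64, the branch above splits each 6-bit code (0..63, which the dequantizer turns into d * sc * (code - 32)) into a low nibble stored in ql[32] and a high bit pair stored in qh[16]. A standalone round-trip sketch of the bit layout, with hypothetical helper names:

    #include <stdint.h>

    // QK_K == 64: split 64 six-bit codes into low nibbles (ql) and high bit pairs (qh),
    // matching the packing in quantize_row_q6_K_reference above.
    static void pack_q6_64(const uint8_t L[64], uint8_t ql[32], uint8_t qh[16]) {
        for (int l = 0; l < 32; ++l) {
            ql[l] = (L[l] & 0xF) | ((L[l + 32] & 0xF) << 4);
        }
        for (int l = 0; l < 16; ++l) {
            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2)
                  | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
        }
    }

    // Recover the six-bit code of weight j (0..63); dequantization uses (code - 32).
    static uint8_t unpack_q6_64(const uint8_t ql[32], const uint8_t qh[16], int j) {
        const uint8_t lo = (j < 32) ? (ql[j] & 0xF) : (ql[j - 32] >> 4);
        const uint8_t hi = (qh[j % 16] >> (2 * (j / 16))) & 3;
        return (uint8_t)(lo | (hi << 4));
    }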
@@ -1611,18 +1812,23 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) {
#if QK_K == 256
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
#else
// TODO
const float d = 0; const float dmin = 0;
#endif
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
@@ -1840,18 +2046,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
#if QK_K == 256
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
#else
// TODO
const float d = 0, dmin = 0;
#endif
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

k_quants.h

@@ -7,7 +7,13 @@
#include <stddef.h>
// Super-block size
#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif
//
// Super-block quantization structures
@@ -32,35 +38,56 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
#ifdef GGML_QKK_64
int8_t scales[K_SCALE_SIZE];
#else
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
// 4-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
#endif
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
// 5-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
// 6-bit quantization
// weight is represented as x = a * q
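As a rough check on what the 64-weight super-blocks cost, the per-block sizes follow directly from the struct definitions above, assuming the usual 2-byte ggml_fp16_t and no padding; the 256-weight case reproduces the "effectively 4.5 / 5.5 bits per weight" figures in the comments. A small sketch of the arithmetic:

    #include <stdio.h>

    int main(void) {
        // bytes per super-block, read off the q4_K / q5_K structs above
        const int q4_256 = 2*2 + 12 + 256/2;          // d, dmin, scales, qs
        const int q4_64  = 4*2 + 64/2;                // d[4], qs
        const int q5_256 = 2*2 + 12 + 256/8 + 256/2;  // d, dmin, scales, qh, qs
        const int q5_64  = 4*2 + 64/8 + 64/2;         // d[4], qh, qs
        printf("q4_K: %.2f bpw at QK_K=256, %.2f bpw at QK_K=64\n",
               8.0*q4_256/256, 8.0*q4_64/64);          // 4.50 vs 5.00
        printf("q5_K: %.2f bpw at QK_K=256, %.2f bpw at QK_K=64\n",
               8.0*q5_256/256, 8.0*q5_64/64);          // 5.50 vs 6.00
        return 0;
    }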