Starting to add k-quantization to ggml
I think it is better to have quantization separate from ggml. For now just adding the k-quants there, but it would be better to also factor out the existing ggml quantizations.
This commit is contained in:
parent
136476e898
commit
8673a41385
4 changed files with 277 additions and 5 deletions
|
@ -370,6 +370,8 @@ endif()
|
||||||
add_library(ggml OBJECT
|
add_library(ggml OBJECT
|
||||||
ggml.c
|
ggml.c
|
||||||
ggml.h
|
ggml.h
|
||||||
|
k_quants.h
|
||||||
|
k_quants.c
|
||||||
${GGML_CUDA_SOURCES}
|
${GGML_CUDA_SOURCES}
|
||||||
${GGML_OPENCL_SOURCES})
|
${GGML_OPENCL_SOURCES})
|
||||||
|
|
||||||
|
|
13
Makefile
13
Makefile
|
@ -210,6 +210,9 @@ $(info )
|
||||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||||
$(CC) $(CFLAGS) -c $< -o $@
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
k_quants.o: k_quants.c k_quants.h ggml.h ggml-cuda.h
|
||||||
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h
|
llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
@ -232,19 +235,19 @@ main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||||
@echo '==== Run ./main -h for help. ===='
|
@echo '==== Run ./main -h for help. ===='
|
||||||
@echo
|
@echo
|
||||||
|
|
||||||
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
|
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o k_quants.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS)
|
quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o k_quants.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o k_quants.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o k_quants.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o k_quants.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
|
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||||
|
|
191
k_quants.c
Normal file
191
k_quants.c
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
#include "k_quants.h"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
// Local MIN/MAX helpers. Any definition that is already in effect is dropped first
// so the two macros below are the ones used in this file.
// Both macros evaluate each argument twice: do not pass expressions with side
// effects, and be aware that a function-call argument may be executed twice.
#undef MIN
|
||||||
|
#undef MAX
|
||||||
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===================== Helper functions
|
||||||
|
//
|
||||||
|
// Round fval to the nearest integer using a float bit trick instead of a libm call.
//
// Adding 12582912.f (1.5 * 2^23) moves the value into [2^23, 2^24), where adjacent
// floats are exactly 1 apart. The addition therefore performs the rounding itself
// (using the current floating-point rounding mode) and leaves the integer in the
// low 23 mantissa bits, offset by 0x00400000 (2^22), which is subtracted again.
//
// The trick only works while the sum stays inside that binade, i.e. for
// |fval| <= 4194303. Both bounds are asserted: an out-of-range negative input
// would otherwise silently yield an unrelated value. NaN also fails the assert.
static inline int nearest_int(float fval) {
    assert(fval >= -4194303.f && fval <= 4194303.f);
    const float shifted = fval + 12582912.f;
    int bits;
    memcpy(&bits, &shifted, sizeof(int)); // type-pun via memcpy to stay clear of aliasing UB
    return (bits & 0x007fffff) - 0x00400000;
}
|
||||||
|
|
||||||
|
// Quantize the n floats in x to integer levels and return the scale that maps
// levels back to values.
//
//   n         number of elements in x and L
//   nmax      levels are clamped to [-nmax, nmax-1]; they are stored in L biased
//             by nmax, so L[i] ends up in [0, 2*nmax-1]
//   x         input values
//   L         output levels; x[i] is approximated by scale * (L[i] - nmax)
//   rmse_type 0: use the initial scale as is and return immediately;
//             otherwise the scale is refined, and for values >= 3 a set of
//             alternative starting scales is tried as well.
//             rmse_type % 2 == 1 weights each element by x[i]^2, else by 1.
//
// Returns 0 (with all L[i] = 0) when the group is treated as all-zero.
static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
    // Locate the element with the largest magnitude; max keeps its sign.
    float max = 0;
    float amax = 0;
    for (int i = 0; i < n; ++i) {
        float ax = fabsf(x[i]);
        if (ax > amax) { amax = ax; max = x[i]; }
    }
    // Groups whose largest magnitude is below this threshold are treated as all-zero.
    // Testing only for an exact 0 is not enough: for tiny non-zero inputs the products
    // x*x and x*x*x used below underflow to 0 in float, which turns sumlx/suml2 into
    // 0/0 or 1/scale into inf, and those values would then be fed to nearest_int().
    const float group_max_eps = 1e-15f;
    if (amax < group_max_eps) { // all (effectively) zero
        for (int i = 0; i < n; ++i) {
            L[i] = 0;
        }
        return 0.f;
    }
    // Initial inverse scale: the largest-magnitude element maps to -nmax, the one end
    // of the asymmetric range [-nmax, nmax-1] that reaches magnitude nmax.
    float iscale = -nmax / max;
    if (rmse_type == 0) {
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
        }
        return 1/iscale;
    }
    int weight_type = rmse_type%2;
    // For fixed levels l the weighted least-squares scale is sumlx/suml2, and
    // best = sumlx^2/suml2 is the quantity the searches below try to increase.
    float sumlx = 0;
    float suml2 = 0;
    for (int i = 0; i < n; ++i) {
        int l = nearest_int(iscale * x[i]);
        l = MAX(-nmax, MIN(nmax-1, l));
        L[i] = l + nmax;
        float w = weight_type == 1 ? x[i] * x[i] : 1;
        sumlx += w*x[i]*l;
        suml2 += w*l*l;
    }
    float scale = sumlx/suml2;
    float best = scale * sumlx;
    // Pass 1 (at most 3 rounds): re-quantize with the current least-squares scale and
    // keep the result only if some level changed and the objective improved.
    for (int itry = 0; itry < 3; ++itry) {
        iscale = 1/scale;
        float slx = 0;
        float sl2 = 0;
        bool changed = false;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            l = MAX(-nmax, MIN(nmax-1, l));
            if (l + nmax != L[i]) { changed = true; }
            float w = weight_type == 1 ? x[i] * x[i] : 1.f;
            slx += w*x[i]*l;
            sl2 += w*l*l;
        }
        if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
        }
        sumlx = slx; suml2 = sl2;
        scale = sumlx/suml2;
        best = scale * sumlx;
    }
    // Pass 2 (at most 5 sweeps): for each element, take its contribution out of the
    // sums, pick the level that fits the scale implied by the remaining elements,
    // and accept the change only if the objective improves.
    for (int itry = 0; itry < 5; ++itry) {
        int n_changed = 0;
        for (int i = 0; i < n; ++i) {
            float w = weight_type == 1 ? x[i]*x[i] : 1;
            int l = L[i] - nmax;
            float slx = sumlx - w*x[i]*l;
            if (slx > 0) {
                float sl2 = suml2 - w*l*l;
                int new_l = nearest_int(x[i] * sl2 / slx);
                new_l = MAX(-nmax, MIN(nmax-1, new_l));
                if (new_l != l) {
                    slx += w*x[i]*new_l;
                    sl2 += w*new_l*new_l;
                    // Cross-multiplied form of slx^2/sl2 > sumlx^2/suml2.
                    if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
                        L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
                        scale = sumlx / suml2; best = scale * sumlx;
                        ++n_changed;
                    }
                }
            }
        }
        if (!n_changed) { break; }
    }
    if (rmse_type < 3) {
        return scale;
    }
    // Pass 3: try starting scales that map the largest element to nmax + 0.1*is for
    // is in -4..4 (0 skipped, it is the initial guess) and keep the best one.
    for (int is = -4; is <= 4; ++is) {
        if (is == 0) {
            continue;
        }
        iscale = -(nmax + 0.1f*is) / max;
        sumlx = suml2 = 0;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale * x[i]);
            l = MAX(-nmax, MIN(nmax-1, l));
            float w = weight_type == 1 ? x[i] * x[i] : 1;
            sumlx += w*x[i]*l;
            suml2 += w*l*l;
        }
        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
            for (int i = 0; i < n; ++i) {
                int l = nearest_int(iscale * x[i]);
                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
            }
            scale = sumlx/suml2; best = scale*sumlx;
        }
    }
    return scale;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reference (scalar) quantization of k floats from x into k / QK_K block_q6_K
// super-blocks written to y. k must be a multiple of QK_K.
//
// Per super-block:
//   1. each group of 16 values gets its own float scale from make_qx_quants();
//   2. the 16 scales are quantized to int8 relative to the one with the largest
//      magnitude, whose fp16-encoded fraction is stored in d;
//   3. the values are re-quantized against the scales as they will be decoded
//      (fp16 d times int8 scale) and clamped to [-32, 31], stored biased by 32;
//   4. the resulting 6-bit values are split into low nibbles (ql) and upper
//      two bits (qh).
//
// NOTE(review): the function has internal linkage (static), so it is only
// callable from this translation unit - confirm that this is intended.
static void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    int8_t L[QK_K];
    float  scales[QK_K/16];

    for (int i = 0; i < nb; i++) {

        // Scale of largest magnitude (sign kept in max_scale).
        float max_scale = 0;
        float max_abs_scale = 0;

        for (int ib = 0; ib < QK_K/16; ++ib) {

            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
            scales[ib] = scale;

            const float abs_scale = fabsf(scale);
            if (abs_scale > max_abs_scale) {
                max_abs_scale = abs_scale;
                max_scale = scale;
            }

        }

        // Every group scale is zero (e.g. an all-zero super-block). Without this
        // early exit iscale would be -128/0 and iscale*scales[ib] would be NaN,
        // which nearest_int() rejects. Emit a zeroed block instead; with d == 0
        // the block presumably decodes to all zeros - TODO confirm against the
        // dequantization code.
        if (!max_abs_scale) {
            memset(&y[i], 0, sizeof(block_q6_K));
            y[i].d = ggml_fp32_to_fp16(0.f);
            x += QK_K;
            continue;
        }

        // The largest-magnitude scale maps to -128, the full int8 range.
        float iscale = -128.f/max_scale;
        y[i].d = ggml_fp32_to_fp16(1/iscale);
        for (int ib = 0; ib < QK_K/16; ++ib) {
            // Hoisted out of MIN() so nearest_int() runs once, not twice.
            const int qscale = nearest_int(iscale*scales[ib]);
            y[i].scales[ib] = MIN(127, qscale);
        }

        // Re-quantize against the scale as it will actually be reconstructed.
        for (int j = 0; j < QK_K/16; ++j) {
            float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
            if (!d) {
                // Keep the levels make_qx_quants() produced for this group.
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-32, MIN(31, l));
                L[16*j + ii] = l + 32;
            }
        }

        // Pack 128 six-bit values at a time: 64 bytes of low nibbles, 32 bytes of
        // upper two bits (four values per byte).
        uint8_t * restrict ql = y[i].ql;
        uint8_t * restrict qh = y[i].qh;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                const uint8_t q1 = L[j + l +  0] & 0xF;
                const uint8_t q2 = L[j + l + 32] & 0xF;
                const uint8_t q3 = L[j + l + 64] & 0xF;
                const uint8_t q4 = L[j + l + 96] & 0xF;
                ql[l+ 0] = q1 | (q3 << 4);
                ql[l+32] = q2 | (q4 << 4);
                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
            }
            ql += 64;
            qh += 32;
        }

        x += QK_K;

    }
}
|
||||||
|
|
76
k_quants.h
Normal file
76
k_quants.h
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
//
|
||||||
|
// 3-6 bit quantization in super-blocks
|
||||||
|
//
|
||||||
|
|
||||||
|
// Super-block size
|
||||||
|
#define QK_K 256
|
||||||
|
|
||||||
|
// 3-bit quantization
|
||||||
|
// weight is represented as x = a * q
|
||||||
|
// 16 blocks of 16 elements each
|
||||||
|
// Effectively 3.4375 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
|
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||||
|
ggml_fp16_t d; // super-block scale
|
||||||
|
} block_q3_K;
|
||||||
|
// 11*QK_K/64 bytes = hmask (QK_K/8 = 8*QK_K/64) + scales (3*QK_K/64); QK_K/4 is qs.
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
|
||||||
|
|
||||||
|
// 4-bit quantization
|
||||||
|
// 8 blocks of 32 elements each (QK_K/32)
|
||||||
|
// weight is represented as x = a * q + b
|
||||||
|
// Effectively 4.5 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
ggml_fp16_t d; // super-block scale for quantized scales
|
||||||
|
ggml_fp16_t dmin; // super-block scale for quantized mins
|
||||||
|
uint8_t scales[3*QK_K/64]; // 6-bit quantized values; presumably both the per-block scales and mins (12 bytes = 16 x 6 bits) - TODO confirm packing
|
||||||
|
uint8_t qs[QK_K/2]; // 4-bit quants
|
||||||
|
} block_q4_K;
|
||||||
|
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
||||||
|
|
||||||
|
// 5-bit quantization
|
||||||
|
// 8 blocks of 32 elements each (QK_K/32)
|
||||||
|
// weight is represented as x = a * q + b
|
||||||
|
// Effectively 5.5 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
ggml_fp16_t d; // super-block scale for quantized scales
|
||||||
|
ggml_fp16_t dmin; // super-block scale for quantized mins
|
||||||
|
uint8_t scales[3*QK_K/64]; // 6-bit quantized values; presumably both the per-block scales and mins (12 bytes = 16 x 6 bits) - TODO confirm packing
|
||||||
|
uint8_t qh[QK_K/8]; // quants, high bit (one bit per quant)
|
||||||
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||||
|
} block_q5_K;
|
||||||
|
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
|
||||||
|
|
||||||
|
// 6-bit quantization
|
||||||
|
// weight is represented as x = a * q
|
||||||
|
// 16 blocks of 16 elements each
|
||||||
|
// Effectively 6.5625 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||||
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||||
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||||
|
ggml_fp16_t d; // super-block scale
|
||||||
|
} block_q6_K;
|
||||||
|
// 3*QK_K/4 bytes = ql (QK_K/2) + qh (QK_K/4); QK_K/16 is the int8 scales.
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
|
||||||
|
|
||||||
|
// 8-bit super-block of QK_K quants with a single float scale.
// This is only used for intermediate quantization and dot products
|
||||||
|
typedef struct {
|
||||||
|
float d; // delta
|
||||||
|
int8_t qs[QK_K]; // quants
|
||||||
|
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
|
||||||
|
} block_q8_K;
|
||||||
|
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
|
||||||
|
|
||||||
|
|
||||||
|
// Quantization
|
||||||
|
// Quantizes k floats from x into k / QK_K block_q6_K super-blocks at y.
// NOTE(review): this prototype is `static` in a header, so each file including it
// gets its own internal-linkage declaration; the definition in k_quants.c is
// static as well and is therefore not reachable from other translation units.
// Confirm whether external linkage was intended (both declaration and definition
// would have to drop `static` together).
static void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue