milka :) 2023-06-13 10:39:13 +01:00 committed by GitHub
commit 3655865e15
5 changed files with 698 additions and 19 deletions


@ -25,6 +25,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
{"q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S},
{"q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M},
{"q6_K", LLAMA_FTYPE_MOSTLY_Q6_K},
{"qx_0", LLAMA_FTYPE_MOSTLY_QX_0},
};
bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {

ggml.c (581 changes)

@ -488,6 +488,44 @@ int64_t ggml_cycles_per_ms(void) {
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
//
// bit manipulation helpers
//
// writes "bit_count" bits of "data" at a "bit_offset" offset in "dst"
// only used for data <= 16bits; only useful to quantize_qx_0
inline static void write_bits(uint32_t * dst, uint32_t bit_offset, uint16_t data, uint16_t bit_count) {
const uint32_t chunk_size = (sizeof(uint32_t) * 8);
const uint32_t chunk_id = bit_offset / chunk_size;
dst = dst + chunk_id;
bit_offset %= (sizeof(uint32_t) * 8);
if (bit_offset + bit_count > chunk_size) {
// first fill the current chunk
uint16_t bitcount_1 = chunk_size - bit_offset;
uint32_t bitmask = ((1 << bitcount_1) - 1) << (bit_offset);
*dst &= ~bitmask;
*dst |= data << bit_offset;
// move onto the next chunk
data >>= bitcount_1;
bit_count -= bitcount_1;
bit_offset = 0;
dst += 1;
bitmask = ((1 << bit_count) - 1) << (bit_offset);
*dst &= ~bitmask;
*dst |= data << bit_offset;
} else {
uint32_t bitmask = ((1 << bit_count) - 1) << (bit_offset);
*dst &= ~bitmask;
*dst |= data << bit_offset;
}
}
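// Usage sketch (illustrative only): packing a 3-bit code, a full fp16 value and a 4-bit code back to back.
// uint32_t buf[4] = {0};
// write_bits(buf, 0, 0x5, 3); // occupies bits [0, 3)
// write_bits(buf, 3, 0xBEEF, 16); // occupies bits [3, 19)
// write_bits(buf, 19, 0x9, 4); // occupies bits [19, 23)
// buf[0] now holds 0x004DF77D; a reader recovers each field by shifting right by its bit offset and masking.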
//
// quantization
//
@ -835,6 +873,25 @@ typedef struct {
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
// max block size is 256 because some feed_forward tensors have a width of 11008 weights, which is not divisible by 512
#define QKX_0 256
// There is no byte-exact C struct to represent a QX_0 block, but a high-level representation of a block is:
// uint64_t fp16_indicators[QKX_0 / 64]; // one bit per weight: fp16 (1) or quantized (0)
// ggml_fp16_t min;
// ggml_fp16_t delta;
// uint8_t block_metadata; // currently just the qbit count
// [bitstream of quantized and fp16 weights]
// Quantization parameters for QX_0 (used only when running ./quantize, irrelevant during inference)
// Quantization starts at QX_0_STARTING_QBITS bits, and then moves down to QX_0_START_OF_ATTEMPTED_QBITS
// and tries lower and lower bit precisions from there
// TODO maybe move these to commandline arguments...?
#define QX_0_STARTING_QBITS 4
#define QX_0_START_OF_ATTEMPTED_QBITS 2
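// Illustrative only (not part of the implementation): a rough per-block byte budget implied by the
// layout above, assuming the bitstream is padded to a byte boundary by promoting quantized weights
// to fp16, as ggml_quantize_qx_0 does below.
static inline size_t qx_0_block_size_estimate(int qbits, int fp16_count) {
    size_t bits = (size_t) fp16_count * 16 + (size_t) (QKX_0 - fp16_count) * qbits;
    while (bits % 8 != 0) {
        bits += 16 - qbits; // promote one quantized weight to fp16
    }
    return (QKX_0 / 64) * sizeof(uint64_t) // fp16/quantized indicator bits
         + 2 * sizeof(ggml_fp16_t)         // min + delta
         + 1                               // metadata byte (qbits)
         + bits / 8;                       // packed weight bitstream
}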
// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
static const int qk = QK4_0;
@ -1530,6 +1587,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
}
}
static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@ -1627,6 +1685,16 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
},
#endif
// GGML_TYPE_QX_0's quantize/dequantize functions don't share the other quantization types' interface,
// so we supply NULL here and special-case QX_0 with if statements in the places where these would be used
[GGML_TYPE_QX_0] = {
.dequantize_row_q = (dequantize_row_q_t) NULL,
.quantize_row_q = NULL,
.quantize_row_q_reference = (quantize_row_q_t) NULL,
.quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_qx_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
},
};
// For internal test use
@ -3122,6 +3190,197 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
#endif
}
__attribute__((optimize("unroll-loops")))
static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
uint32_t nb = n / QKX_0;
GGML_ASSERT(QKX_0 % QK8_0 == 0);
*s = 0;
const uint8_t * quant_row = (const uint8_t *) vx;
const block_q8_0 * restrict column = vy;
uint32_t column_i = 0; // current index in column
// f32_row_data is a buffer which stores the dequantized float values of the current block
float f32_row_data[QKX_0];
// The __AVX2__ path doesn't seem to make much of a difference here;
// a lot of optimization is still possible, including possibly using AVX2
// for the dequantization itself
#if defined(__AVX2__)
__m256 rolling_sum = _mm256_setzero_ps();
#endif
float qvals[1 << 4];
for (uint32_t b = 0; b < nb; b++) {
float * row_ptr = f32_row_data;
const uint64_t * block_start = (const uint64_t *) quant_row;
const float min_value = GGML_FP16_TO_FP32(*((const uint16_t *) (block_start + (QKX_0 / 64))));
float mult_value = GGML_FP16_TO_FP32(*((const uint16_t *) (block_start + (QKX_0 / 64)) + 1));
const uint16_t * data_start = (const uint16_t *) (block_start + (QKX_0 / 64)) + 2;
const uint8_t qbits = *((const uint8_t *) data_start);
data_start = (const uint16_t*) ((const uint8_t*) data_start + 1);
quant_row = (const uint8_t * ) data_start;
// Any qbits value is supported, but the size of qvals needs to be changed to 1 << max_expected_qbits.
// So if you have at most 7-bit values, change qvals's declaration to qvals[1 << 7].
// Additionally, the "fp16_indicator == 0" optimized branch only works if qbits is 3 or a power of 2,
// so feel free to disable it entirely and fall back to the slower "else" branch, which works for
// pretty much any qbits value.
GGML_ASSERT(qbits <= 4);
uint32_t offset = 0;
uint8_t data_offset = 0;
// Cache quantized values
for (int i = 0; i < (1 << qbits); i++) {
qvals[i] = min_value + mult_value * i;
}
// Parse in sub-blocks of 64 weights, since each sub-block is managed by a single uint64_t whose bits decide
// whether a given weight is stored as fp16 or quantized. This lets us do a fast fp16_indicator == 0 check
// (i.e. all weights in the sub-block are quantized) to speed up performance
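// e.g. fp16_indicator = 0x0000000000000401 means weights 0 and 10 of this 64-weight sub-block are stored
// as fp16, while the remaining 62 weights are quantized (illustrative value only)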
for (int subblock_i = 0; subblock_i < QKX_0 / 64; subblock_i++) {
uint64_t fp16_indicator = block_start[subblock_i];
// all weights in this sub-block are quantized; NOTE this fast path ONLY works when qbits <= 4, since only then does (qbits != 3) imply that qbits is a power of 2
if (fp16_indicator == 0) {
if (qbits == 3) {
// same principle as on the regular data_offset branch, but this time the qbits cross byte boundaries, so we need to manage it by hand
for (int i = 0; i < 5; i++) {
for (int k = 0; k < 11; k ++) {
// here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range
row_ptr[i * 11 + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))];
}
data_start += 2; // this is the same event as in if (data_offset >= 16), but happening twice
data_offset += 1; // it's actually +33, but the "+32" is represented in data_start above, so the remainder is simply +1
}
for (int k = 0; k < 9; k ++) {
// here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range
row_ptr[55 + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))];
}
data_start += 1;
data_offset += 9 * 3 - 16;
if (data_offset >= 16) {
data_start += 1;
data_offset -= 16;
}
} else if (data_offset == 0) {
// This only properly works for QBits = power of 2
const uint8_t data_block_size = 64;
// we can take a full 64bit block
const uint8_t weights_per_u64_data_block = data_block_size / qbits;
const uint8_t num_of_data_blocks_needed = 64 / weights_per_u64_data_block; // because we have 64 qbit-sized weights here
for (int i = 0; i < num_of_data_blocks_needed; i++) {
for (int k = 0; k < weights_per_u64_data_block; k ++) {
row_ptr[i * weights_per_u64_data_block + k] = qvals[(((const uint64_t *) data_start)[0] >> (k * qbits)) & ((1 << qbits) - 1)];
}
data_start += (data_block_size / 8) / sizeof(uint16_t);
}
} else {
// We are doing u32 instead of a simple u64, since data_offset may not be 0 and we need to account for that
const uint8_t data_block_size = 32;
const uint8_t weights_per_u32_data_block = data_block_size / qbits;
const uint8_t num_of_data_blocks_needed = 64 / weights_per_u32_data_block;
for (int i = 0; i < num_of_data_blocks_needed; i++) {
for (int k = 0; k < weights_per_u32_data_block; k ++) {
// here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range
row_ptr[i * weights_per_u32_data_block + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))];
}
data_start += (data_block_size / 8) / sizeof(uint16_t);
}
}
offset += qbits * 64;
} else {
for (int i = 0; i < 64; i++) {
if (fp16_indicator & 1) {
// Current weight is fp16
offset += 16;
row_ptr[i] = GGML_FP16_TO_FP32((((const uint32_t *) data_start)[0] >> data_offset) & ((1 << 16) - 1));
data_start += 1;
} else {
// Current weight is quantized
offset += qbits;
row_ptr[i] = qvals[((((const uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1))];
data_offset += qbits;
if (data_offset >= 16) {
data_start += 1;
data_offset -= 16;
}
}
// Shift the fp16 indicator to the right, to move to the next weight
fp16_indicator >>= 1;
}
}
for (int jb = 0; jb < 64 / QK8_0; jb++) {
#if defined(__AVX2__)
__m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_i].d));
for (int i = 0; i < QK8_0/8; i++) {
__m128i test = _mm_loadu_si128((const __m128i *) (column[column_i].qs + i * 8));
__m256i work = _mm256_cvtepi8_epi32(test);
__m256 workf = _mm256_cvtepi32_ps(work);
// multiply with the corresponding 8 dequantized floats of the row (from f32_row_data)
__m256 row = _mm256_loadu_ps(row_ptr + jb * QK8_0 + i * 8);
workf = _mm256_mul_ps(workf, row);
rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum);
}
#else
// scalar
float sub_sum = 0;
for (int i = 0; i < QK8_0; i++) {
sub_sum += row_ptr[jb * QK8_0 + i] * column[column_i].qs[i];
}
sub_sum *= GGML_FP16_TO_FP32(column[column_i].d);
*s += sub_sum;
#endif
column_i += 1;
}
row_ptr += 64;
}
GGML_ASSERT(offset % 8 == 0);
quant_row += offset / 8;
}
#if defined(__AVX2__)
float rolling_sum_vec[8];
_mm256_store_ps(rolling_sum_vec, rolling_sum);
for (int i = 0; i < 8; i++) {
*s += rolling_sum_vec[i];
}
#endif
}
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
@ -3514,11 +3773,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q6_K] = QK_K,
[GGML_TYPE_Q8_K] = QK_K,
#endif
// [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed block size
[GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 20, "GGML_BLCK_SIZE is outdated");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
@ -3537,11 +3797,12 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q6_K] = sizeof(block_q6_K),
[GGML_TYPE_Q8_K] = sizeof(block_q8_K),
#endif
// [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed type size
[GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_SIZE is outdated");
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@ -3559,11 +3820,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_K] = "q5_K",
[GGML_TYPE_Q6_K] = "q6_K",
[GGML_TYPE_Q8_K] = "q8_K",
[GGML_TYPE_QX_0] = "qx_0",
[GGML_TYPE_I8] = "i8",
[GGML_TYPE_I16] = "i16",
[GGML_TYPE_I32] = "i32",
};
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_NAME is outdated");
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = false,
@ -3580,11 +3842,12 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_K] = true,
[GGML_TYPE_Q6_K] = true,
[GGML_TYPE_Q8_K] = true,
[GGML_TYPE_QX_0] = true,
[GGML_TYPE_I8] = false,
[GGML_TYPE_I16] = false,
[GGML_TYPE_I32] = false,
};
static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
static_assert(GGML_TYPE_COUNT == 20, "GGML_IS_QUANTIZED is outdated");
static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"NONE",
@ -3890,6 +4153,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_QX_0: wtype = GGML_TYPE_QX_0; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
}
@ -4266,7 +4530,14 @@ struct ggml_tensor * ggml_new_tensor_impl(
}
result->nb[0] = GGML_TYPE_SIZE[type];
result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
if (type == GGML_TYPE_QX_0) {
// QX_0 doesn't have a set stride size for a row; that value is stored in the "extra" part of the tensor
result->nb[1] = 0;
} else {
result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
}
for (int i = 2; i < GGML_MAX_DIMS; i++) {
result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
}
@ -7719,6 +7990,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
@ -8027,6 +8299,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
{
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
@ -8154,6 +8427,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
default:
{
GGML_ASSERT(false);
@ -10189,13 +10463,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
void * src0_row;
if (type == GGML_TYPE_QX_0) {
if (ir > 0) {
src0_row = (void *) ((char *) src0->data + ((uint64_t *) src0->extra)[ir - 1]);
} else {
src0_row = (void *) ((char *) src0->data);
}
} else {
src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
assert(ne00 % 32 == 0);
}
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
@ -10231,6 +10514,7 @@ static void ggml_compute_forward_mul_mat(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
@ -10419,6 +10703,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
default:
{
GGML_ASSERT(false);
@ -10589,6 +10874,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@ -11141,6 +11427,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_QX_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -11218,6 +11505,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_QX_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -16236,7 +16524,276 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
return (n/QK8_0*sizeof(block_q8_0));
}
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) {
assert(n % QKX_0 == 0);
assert(tensor_width % QKX_0 == 0);
const int nb = n / QKX_0;
uint8_t * dst_8 = dst;
uint64_t dst_offset = 0;
// define max quantization errors for every bit precision
// i.e. max_quantization_errors[1] holds max error for 1bit quantized weights
// max_quantization_errors[2] holds max error for 2bit quantized weights
// max_quantization_errors[3] holds max error for 3bit quantized weights
// etc.
//
// max quantization error here means that every single quantized weight is within
// said value (e.g. 0.004) of its original value
//
// this can be replaced with a max allowed RMSE, a set percentage of weights being within
// a certain range, etc... The current implementation here is pretty much just an example
float max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004};
// How maximum quantization error is implemented here:
//
// Each block holds both fp16 and "qbit" quantized weights mixed together arbitrarily.
// This mixing is handled by a few numbers at the start of each block, the bit of each number
// indicating if a given weight (corresponding to that bit) is stored on 16bit or is quantized.
//
// There is a metadata byte which indicates the qbit precision of the current block, and
// its values are in [1,2,3,4], but this can easily be extended to allow any other bit precisions,
// such as 5, 6, 9, 13 bits or anything else.
//
// To guarantee that each weight is within max_quantization_error, we first need to look at what range
// of values this allows us to have. Since we have "qbits" bits, then we have (1 << qbits) possible values
// the quantized weights can take. The maximum distance between two quantized points can be "2 * max_quantization_error"
// since any weight situated within these two points will be <= max_quantization_error of its closest point.
//
// A visual 2bit example would be: -->|<---->|<---->|<---->|<--
// Where "|" are the quantized points, and "-->" represents max_quantization_error on the number line.
//
// Any value outside this range will have to be kept on 16bit, since it cannot be within max_quantization_error
// of its quantized point.
//
//
// Note: Each block is kept byte-aligned for simplicity, which means that the number of 16bit weights and qbit weights
// in the bitstream has to be balanced such that the total number of bits is divisible by 8.
// e.g. If we have 3 4bit values and 253 16bit values, we will need to revert a 4bit value to 16bit in order
// to keep the total number of bits divisible by 8. If we were to quantize a weight instead, we would lose
// the "max_quantization_error" guarantee. However, each block doesn't need to remain byte-aligned, the requirement
// only holds for each row, so a big potential improvement could be made here, since we have quite a few unnecessary
// 16bit weights.
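// Worked example (illustrative numbers): with max_quantization_error = 0.004 and qbits = 2 there are
// 4 quantization points spaced mult_range = 2 * 0.004 = 0.008 apart, so with the range centered on the
// mean the lowest point sits at mean - 0.012 and the highest at mean + 0.012; any weight with
// |x - mean| > 0.004 * 4 = 0.016 cannot be within 0.004 of a point and is therefore kept as fp16.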
for (int i = 0; i < nb; i++) {
// each 64bit value holds one bit per weight, indicating whether that weight is stored
// as fp16 or is quantized. "QKX_0 / 64" is here since we need multiple 64bit numbers if
// the QX_0 block is larger than 64 weights.
uint64_t fp16_indicators[QKX_0 / 64];
memset(fp16_indicators, 0, sizeof(uint64_t) * (QKX_0 / 64));
uint8_t qbits = QX_0_STARTING_QBITS;
float thresh = max_quantization_errors[qbits] * (1 << qbits);
int fp16_count = 0;
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
if (fabsf(x) > thresh) {
// store this value on 16bits
fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64);
fp16_count += 1;
}
}
uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
while ((total_bits % 8) != 0) {
total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one (needed for a block's byte alignment)
}
float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1));
float mult_range = 2 * max_quantization_errors[qbits];
// The quantizer starts at a QX_0_STARTING_QBITS quantized block (e.g. 4bits), but then
// attempts to move to a lower precision defined by QX_0_START_OF_ATTEMPTED_QBITS.
// It keeps checking whether each lower precision, down to 1 bit, leads to a smaller block (and therefore file) size.
//
// The decrease in precision does not always lead to a smaller file when we need to maintain
// a fixed max quantization error, since lower bits mean a smaller value range, which might lead
// to more values being moved to 16bits, which might in the end actually increase our block's size.
//
// If values are very close to the mean, then a lower precision is more advantageous since we don't
// need a large quantization range, but otherwise it's likely more beneficial to stay at a higher precision.
// The loop below calculates this ideal trade-off for us!
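// Worked example (illustrative numbers): at 4 bits a 256-weight block with 6 outliers costs
// 6*16 + 250*4 = 1096 bits; if dropping to 2 bits turns 30 weights into outliers, the block costs
// 30*16 + 226*2 = 932 bits and 2 bits wins, but with 80 outliers it would cost 80*16 + 176*2 = 1632 bits
// and we would stay at 4 bits.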
for (uint8_t test_qbit = QX_0_START_OF_ATTEMPTED_QBITS; test_qbit >= 1; test_qbit--) {
// calculate the mean of non-fp16 values and define that as the center of the quantization range
float mean = 0;
for (int j = 0; j < QKX_0; j++) {
if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x_fp32 = src[i * QKX_0 + j];
mean += x_fp32;
}
}
mean /= (QKX_0 - fp16_count);
uint16_t total_fp16s_in_test_qbit = 0;
thresh = max_quantization_errors[test_qbit] * (1 << test_qbit);
for (int j = 0; j < QKX_0; j++) {
if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x = src[i * QKX_0 + j];
// new outlier found for our current qbit
if (x < mean - thresh || x > mean + thresh) {
total_fp16s_in_test_qbit += 1;
}
} else {
total_fp16s_in_test_qbit += 1;
}
}
uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit);
while ((total_bits_in_test_qbit % 8) != 0) {
total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one
}
if (total_bits_in_test_qbit < total_bits) {
total_bits = total_bits_in_test_qbit;
qbits = test_qbit;
min_value = mean - (max_quantization_errors[test_qbit] * ((1 << qbits) - 1));
mult_range = 2 * max_quantization_errors[test_qbit];
for (int j = 0; j < QKX_0; j++) {
if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x = src[i * QKX_0 + j];
// mark outlier as stored on 16bit
if (x < mean - thresh || x > mean + thresh) {
fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64);
fp16_count += 1;
}
}
}
uint16_t total_test_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
while ((total_test_bits % 8) != 0) {
total_test_bits += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one
}
GGML_ASSERT(total_bits == total_test_bits);
}
}
// keep converting the largest-magnitude quantized weights to fp16 until the block's bitstream is byte-aligned
while (((QKX_0 - fp16_count) * qbits) % 8 != 0) {
float maxi = 0;
int target = -1;
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
// weight is not on 16bit
if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float diff = fabsf(x);
if (diff > maxi || target == -1) {
maxi = diff;
target = j;
}
}
}
GGML_ASSERT(target != -1);
fp16_indicators[target / 64] |= (uint64_t) 1 << (target % 64);
fp16_count += 1;
}
// if "i" indicates that this is the first block of a new row, store the byte offset
// at which that row starts
if (((i * QKX_0) % tensor_width == 0) && i != 0) {
uint32_t row = (i * QKX_0) / tensor_width;
extra_data[row - 1] = dst_offset;
}
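// e.g. for a tensor with 2 rows of width 512 (2 QX_0 blocks per row), extra_data[0] receives the byte
// offset at which row 1 starts, and extra_data[1] (written after this loop) receives the total byte size
// of the tensor (illustrative shape only)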
// write the fp16 indicators to dst
uint64_t * stored_fp16_indicators = (uint64_t *) (dst_8 + dst_offset);
for (int j = 0; j < QKX_0 / 64; j++) {
stored_fp16_indicators[j] = fp16_indicators[j];
}
dst_offset += (QKX_0 / 64) * sizeof(uint64_t);
// Each weight is stored as min_value + mult * quantized_weight
// Similar to Zero-point quantization, or Q4_1
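// e.g. with qbits = 3, max error 0.004 and a mean of 0: min_value = -0.028 and mult = 0.008,
// so a stored 3-bit code of 5 decodes to -0.028 + 0.008 * 5 = 0.012 (illustrative numbers only)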
// Write min value and multiplier to dst
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value);
dst_offset += sizeof(uint16_t);
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range);
dst_offset += sizeof(uint16_t);
// Store the "metadata" byte (for now it's just "qbits")
*((uint8_t*) (dst_8 + dst_offset)) = qbits;
dst_offset += sizeof(uint8_t);
// Store the quantization pivots / points
// IMPORTANT: Change qvals's size depending on the maximum qbits expected
GGML_ASSERT(qbits <= 8);
float qvals[1 << 8];
for (int j = 0; j < (1 << qbits); j++) {
qvals[j] = min_value + (mult_range * j);
}
uint64_t bit_offset = 0;
uint32_t * data = (uint32_t*) (dst_8 + dst_offset);
int fp16_count_chk = 0;
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
if (fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) {
ggml_fp16_t x_f16 = ggml_fp32_to_fp16(x);
// store the full fp16 weight
write_bits(data, bit_offset, x_f16, 16);
bit_offset += 16;
fp16_count_chk += 1;
} else {
uint8_t q = 0;
float min_dist = fabsf(x - qvals[0]);
// find closest quantization point
for (int iv = 0; iv < (1 << qbits); iv++) {
float dist = fabsf(x - qvals[iv]);
if (dist < min_dist) {
q = iv;
min_dist = dist;
}
}
write_bits(data, bit_offset, q, qbits);
bit_offset += qbits;
}
}
// check that the reported fp16_count is consistent with the bits stored in fp16_indicators
GGML_ASSERT(fp16_count == fp16_count_chk);
// check that the number of bits from quantized values is divisible by 8
GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0);
dst_offset += ((QKX_0 - fp16_count) * qbits) / 8;
dst_offset += fp16_count * 2;
}
// store the total size of the tensor as the last element of extra_data
extra_data[n / tensor_width - 1] = dst_offset;
return dst_offset;
}
// Pass in additional information such as the tensor's "extra_data" and width, since QX_0 needs this info. We can't pass in a pointer to
// a ggml_tensor (since none exists where ggml_quantize_chunk is called), nor to a llama_load_tensor, since ggml.c doesn't have access to that struct
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) {
size_t result = 0;
switch (type) {
case GGML_TYPE_Q4_0:
@ -16301,6 +16858,10 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
} break;
#endif
case GGML_TYPE_QX_0:
{
result = ggml_quantize_qx_0(src, dst, n, hist, extra_data, tensor_width);
} break;
default:
assert(false);
}

ggml.h (5 changes)

@ -248,6 +248,7 @@ extern "C" {
GGML_TYPE_Q5_K = 13,
GGML_TYPE_Q6_K = 14,
GGML_TYPE_Q8_K = 15,
GGML_TYPE_QX_0 = 16,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -276,6 +277,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
GGML_FTYPE_MOSTLY_QX_0 = 15, // except 1d tensors
};
// available tensor operations:
@ -1135,13 +1137,14 @@ extern "C" {
// quantization
//
GGML_API size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width);
GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width);
//
// system info

llama.cpp (129 changes)

@ -342,8 +342,11 @@ struct llama_load_tensor_shard {
enum ggml_type type;
size_t file_idx;
size_t file_off;
size_t extra_data_file_off;
void calc_size() {
// For QX_0 the size is filled in manually elsewhere, since it comes from extra_data
GGML_ASSERT(type != GGML_TYPE_QX_0);
size = llama_calc_tensor_size(ne, type);
}
};
@ -364,6 +367,7 @@ struct llama_load_tensor {
size_t size;
struct ggml_tensor * ggml_tensor = NULL;
uint8_t * data;
uint64_t * extra_data = NULL;
llama_load_tensor(const std::string & name) : name(name) {}
@ -424,7 +428,18 @@ struct llama_load_tensor {
}
void calc_size() {
size = llama_calc_tensor_size(ne, type);
// For QX_0 the size comes from extra_data, but since extra_data might not be initialized here
// we can take it from the shard instead
if (type == GGML_TYPE_QX_0) {
GGML_ASSERT(shards.size() == 1);
GGML_ASSERT(ne.size() == 2);
size = shards.at(0).size;
GGML_ASSERT(size != 0);
} else {
size = llama_calc_tensor_size(ne, type);
}
}
};
@ -520,6 +535,7 @@ struct llama_file_loader {
shard.ne.resize(n_dims);
file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims);
std::string name = file.read_string(name_len);
if (n_dims < 1 || n_dims > 2) {
throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims));
}
@ -536,6 +552,7 @@ struct llama_file_loader {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
break;
default: {
throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type));
@ -546,12 +563,38 @@ struct llama_file_loader {
// skip to the next multiple of 32 bytes
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
}
if (shard.type == GGML_TYPE_QX_0) {
shard.extra_data_file_off = file.tell();
// seek to just before the last element of extra_data
file.seek(sizeof(uint64_t) * (shard.ne[1] - 1), SEEK_CUR);
// get the tensor's size from here
uint64_t tensor_size = 0;
file.read_raw(&tensor_size, sizeof(uint64_t));
shard.size = tensor_size;
// realign, just in case extra_data isn't a multiple of 32B
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
} else {
shard.extra_data_file_off = 0;
}
shard.file_idx = file_idx;
shard.file_off = file.tell();
shard.calc_size();
if (shard.type != GGML_TYPE_QX_0) {
shard.calc_size();
}
file.seek(shard.size, SEEK_CUR);
// QX_0's data may not be 32-byte aligned
if (shard.type == GGML_TYPE_QX_0) {
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
}
auto it = tensors_map.name_to_idx.find(name);
size_t idx;
if (it != tensors_map.name_to_idx.end()) {
@ -602,7 +645,9 @@ struct llama_file_saver {
file.write_raw(&token_score.score, sizeof(token_score.score));
}
}
void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
// pass extra_data by reference to avoid excessive copying
void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size, llama_buffer & extra_data) {
switch (new_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@ -616,6 +661,7 @@ struct llama_file_saver {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_QX_0:
break;
default: LLAMA_ASSERT(false);
}
@ -624,9 +670,29 @@ struct llama_file_saver {
file.write_u32(new_type);
file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size());
file.write_raw(tensor.name.data(), tensor.name.size());
size_t tensor_size;
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
LLAMA_ASSERT(new_size == llama_calc_tensor_size(tensor.ne, new_type));
// The tensor's size for QX_0 is stored in the last element of extra_data
if (new_type == GGML_TYPE_QX_0) {
file.write_raw(extra_data.addr, sizeof(uint64_t) * tensor.ne[1]);
tensor_size = ((uint64_t *) extra_data.addr)[tensor.ne[1] - 1];
// realign, just in case extra_data isn't a multiple of 32B
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
} else {
tensor_size = llama_calc_tensor_size(tensor.ne, new_type);
}
LLAMA_ASSERT(new_size == tensor_size);
file.write_raw(new_data, new_size);
// QX_0 data may not be 32-byte aligned
if (new_type == GGML_TYPE_QX_0) {
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
}
}
};
@ -666,7 +732,7 @@ struct llama_model_loader {
bool alignment_prevents_mmap() {
for (const llama_load_tensor & lt : tensors_map.tensors) {
for (const llama_load_tensor_shard & shard : lt.shards) {
if (shard.file_off & 3) {
if ((shard.file_off & 3)) {
return true;
}
}
@ -725,6 +791,7 @@ struct llama_model_loader {
tensor->backend = backend;
lt.ggml_tensor = tensor;
num_ggml_tensors_created++;
return tensor;
}
@ -771,6 +838,13 @@ struct llama_model_loader {
switch(lt.ggml_tensor->backend) {
case GGML_BACKEND_CPU:
lt.ggml_tensor->data = lt.data;
if (lt.type == GGML_TYPE_QX_0) {
// QX_0 uses the extra field to store byte offsets in *data for each row except row 0
// (so extra[0] stores where row 1 starts, extra[1] is for row 2, and the last element
// in extra stores the total tensor size)
lt.ggml_tensor->extra = lt.extra_data;
}
if (use_mmap && lmlock) {
lock_size += lt.size;
lmlock->grow_to(lock_size);
@ -801,9 +875,17 @@ struct llama_model_loader {
}
void load_data_for(llama_load_tensor & lt) {
// QX_0 only supports mmap
GGML_ASSERT(use_mmap || lt.type != GGML_TYPE_QX_0);
if (use_mmap) {
LLAMA_ASSERT(lt.shards.size() == 1);
lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off;
if (lt.shards.at(0).extra_data_file_off != 0) {
lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off);
}
} else if (lt.split_type == SPLIT_NONE) {
llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;
file.seek(lt.shards.at(0).file_off, SEEK_SET);
@ -988,6 +1070,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K";
case LLAMA_FTYPE_MOSTLY_QX_0: return "mostly QX_0";
default: return "unknown, may not work";
}
}
@ -1665,6 +1748,8 @@ static bool llama_eval_internal(
lctx.n_p_eval += N;
}
// fprintf(stderr, "\nmodel eval time: %ldms\n", (ggml_time_us() - t_start_us) / 1000);
// fflush(stderr);
return true;
}
@ -2309,6 +2394,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_QX_0: quantized_type = GGML_TYPE_QX_0; break;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
}
@ -2316,6 +2402,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
nthread = std::thread::hardware_concurrency();
}
// Multithreaded QX_0 quantization is not compatible with the current multithreaded quantization
// implementation: since blocks have an unknown size in bytes, we cannot split the output data into exact
// chunks assigned to each thread. Multithreading would technically only be possible if we quantized
// multiple entire tensors at once, but the overall implementation doesn't seem to allow that to be done easily
if (quantized_type == GGML_TYPE_QX_0) {
nthread = 1;
printf("Setting nthread to 1 due to the implementation for QX_0 quantization being single-threaded.\n");
}
std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false,
/*vocab_only*/ false));
llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype);
@ -2363,12 +2458,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (!params->quantize_output_tensor && tensor.name == "output.weight") {
quantize = false;
}
// Allow only attention and FFN matrices to be quantized under QX_0, since they only require vec_dot
// to be implemented. Output weights and other matrices require more functions to be implemented, so
// for simplicity we'll only quantize attn and ffn for now.
if (quantized_type == GGML_TYPE_QX_0) {
if (tensor.name.find("attention") == std::string::npos && tensor.name.find("feed_forward") == std::string::npos) {
quantize = false;
}
}
quantize = quantize && quantized_type != tensor.type;
enum ggml_type new_type;
void * new_data;
size_t new_size;
llama_buffer work;
llama_buffer extra_data;
if (!quantize) {
new_type = tensor.type;
@ -2421,11 +2527,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
new_data = work.addr;
std::vector<int64_t> hist_cur(1 << 4, 0);
if (new_type == GGML_TYPE_QX_0) {
extra_data.resize(sizeof(uint64_t) * tensor.ne[1]);
}
int chunk_size = 32 * 512;
const int nchunk = (nelements + chunk_size - 1)/chunk_size;
const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
if (nthread_use < 2) {
new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data());
new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data(), (uint64_t *) extra_data.addr, tensor.ne[0]);
} else {
size_t counter = 0;
new_size = 0;
@ -2449,7 +2560,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (local_hist.empty()) {
local_hist.resize(hist_cur.size(), 0);
}
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data());
// pass in NULL for extra_data, since it's only required for QX_0, which doesn't support multithreaded quantization
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data(), NULL, 0);
}
};
if ((int) workers.size() < nthread_use - 1) {
@ -2480,7 +2593,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
}
total_size_org += tensor.size;
total_size_new += new_size;
file_saver.write_tensor(tensor, new_type, new_data, new_size);
file_saver.write_tensor(tensor, new_type, new_data, new_size, extra_data);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);


@ -113,6 +113,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
LLAMA_FTYPE_MOSTLY_QX_0 = 19, // except 1d tensors
};
// model quantization parameters