diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 27a70241c..d42ac9181 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -571,14 +571,15 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
-#define QK4_XS 64
 #define QR4_XS 2
-#define QI4_XS (QK4_XS / (4*QR4_XS))
+#define QI4_XS (QK_K / (4*QR4_XS))
 typedef struct {
     half d;
-    uint8_t qs[QK4_XS/2];
+    uint16_t scales_h;
+    uint8_t scales_l[QK_K/64];
+    uint8_t qs[QK_K/2];
 } block_iq4_xs;
-static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + QK4_XS/2, "wrong iq4_xs block size/padding");
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
 
 #define WARP_SIZE 32
@@ -2439,19 +2440,14 @@ template
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const int i = blockIdx.x;
-    //const block_iq4_xs * x = (const block_iq4_xs *) vx + i*(QK_K/QK4_XS);
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
     const int tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
-    const int ib32 = i*(QK_K/32) + ib;
-    dst_t * y = yy + 32*ib32 + 4*il;
-    const uint8_t * q4 = x[ib32/(QK4_XS/32)].qs + 16*(ib32%(QK4_XS/32)) + 4*il;
-    const float d = (float)x[ib32/(QK4_XS/32)].d;
-    //dst_t * y = yy + i*QK_K + 32*ib + 4*il;
-    //const uint8_t * q4 = x->qs + 16*(ib%(QK4_XS/32)) + 4*il;
-    //const float d = (float)x->d;
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
     for (int j = 0; j < 4; ++j) {
         y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
         y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
@@ -9420,7 +9416,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_IQ4_XS:
-            mul_mat_vec_q_cuda
+            mul_mat_vec_q_cuda
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
         case GGML_TYPE_IQ3_S:
diff --git a/ggml-quants.c b/ggml-quants.c
index df94ba96e..15d027dc6 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -4226,8 +4226,8 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
 }
 
 void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
-    assert(k % QK4_XS == 0);
-    const int nb = k / QK4_XS;
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -4235,10 +4235,12 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
 
         const float d = GGML_FP16_TO_FP32(x[i].d);
 
-        for (int iq = 0; iq < QK4_XS/32; ++iq) {
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
+            const float dl = d * (ls - 32);
             for (int j = 0; j < 16; ++j) {
-                y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf];
-                y[j+16] = d * kvalues_iq4nl[qs[j] >> 4];
+                y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4];
             }
             y += 32;
             qs += 16;
@@ -10448,13 +10450,13 @@ void ggml_vec_dot_iq4_xs_q8_0(int n, float * restrict s, size_t bs, const void *
     UNUSED(bx);
     UNUSED(by);
     UNUSED(bs);
-    assert(n % QK4_XS == 0);
+    assert(n % QK_K == 0);
     static_assert(QK8_0 == 32, "QK8_0 must be 32");
 
     const block_iq4_xs * restrict x = vx;
     const block_q8_0   * restrict y = vy;
 
-    const int nb = n / QK4_XS;
+    const int nb = n / QK_K;
 
 #if defined z__ARM_NEON
     const int8x16_t values = vld1q_s8(kvalues_iq4nl);
@@ -10524,20 +10526,32 @@ void ggml_vec_dot_iq4_xs_q8_0(int n, float * restrict s, size_t bs, const void *
 #else
 
     float sumf = 0;
-    for (int ib = 0; ib < nb; ++ib) {
-        const float d4 = GGML_FP16_TO_FP32(x[ib].d);
-        const uint8_t * qs = x[ib].qs;
-        for (int iq = 0; iq < QK4_XS/32; ++iq) {
-            const float d = GGML_FP16_TO_FP32(y[iq].d)*d4;
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const float d4 = GGML_FP16_TO_FP32(x[ibl].d);
+        uint16_t h = x[ibl].scales_h;
+        const uint8_t * qs = x[ibl].qs;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
+            h >>= 4;
+            const float d1 = GGML_FP16_TO_FP32(y[0].d)*d4*(ls1 - 32);
+            const float d2 = GGML_FP16_TO_FP32(y[1].d)*d4*(ls2 - 32);
             int sumi1 = 0, sumi2 = 0;
             for (int j = 0; j < 16; ++j) {
-                sumi1 += y[iq].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += y[iq].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
+                sumi1 += y[0].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += y[0].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
             }
-            sumf += d * (sumi1 + sumi2);
+            sumf += d1 * (sumi1 + sumi2);
             qs += 16;
+            sumi1 = sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += y[1].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += y[1].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
+            }
+            sumf += d2 * (sumi1 + sumi2);
+            qs += 16;
+            y += 2;
         }
-        y += QK4_XS/32;
     }
     *s = sumf;
 #endif
@@ -12139,53 +12153,44 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
-static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT xb,
-        ggml_fp16_t * dh, uint8_t * q4,
-        float * weight, uint8_t * L,
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
+        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
+        float * scales, float * weight, uint8_t * L,
         const int8_t * values,
         const float * quant_weights) {
 
     const int ntry = 7;
 
     float sigma2 = 0;
-    for (int j = 0; j < block_size; ++j) sigma2 += xb[j]*xb[j];
-    sigma2 *= 2.f/block_size;
-
-    memset(q4, 0, block_size/2);
+    for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/super_block_size;
+    memset(q4, 0, super_block_size/2);
     dh[0] = GGML_FP32_TO_FP16(0.f);
 
-    if (quant_weights) {
-        for (int j = 0; j < block_size; ++j) weight[j] = quant_weights[j] * sqrtf(sigma2 + xb[j]*xb[j]);
-    } else {
-        for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
-    }
-    float amax = 0, max = 0;
-    for (int j = 0; j < block_size; ++j) {
-        float ax = fabsf(xb[j]);
-        if (ax > amax) {
-            amax = ax; max = xb[j];
+
+    float max_scale = 0, amax_scale = 0;
+    for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+        const float * xb = x + ib*block_size;
+        if (quant_weights) {
+            const float * qw = quant_weights + ib*block_size;
+            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        } else {
+            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
         }
-    }
-    if (!amax) {
-        return;
-    }
-    float d = -max/values[0];
-    float id = 1/d;
-    float sumqx = 0, sumq2 = 0;
-    for (int j = 0; j < block_size; ++j) {
-        float al = id*xb[j];
-        int l = best_index_int8(16, values, al);
-        float q = values[l];
-        float w = weight[j];
-        sumqx += w*q*xb[j];
-        sumq2 += w*q*q;
-    }
-    float best_id = id;
-    d = sumqx/sumq2;
-    float best = d*sumqx;
-    for (int itry = -ntry; itry <= ntry; ++itry) {
-        id = (itry + values[0])/max;
-        sumqx = sumq2 = 0;
+        float amax = 0, max = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float ax = fabsf(xb[j]);
+            if (ax > amax) {
+                amax = ax; max = xb[j];
+            }
+        }
+        if (!amax) {
+            scales[ib] = 0;
+            continue;
+        }
+        float d = -max/values[0];
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
         for (int j = 0; j < block_size; ++j) {
             float al = id*xb[j];
             int l = best_index_int8(16, values, al);
@@ -12194,17 +12199,62 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
             sumqx += w*q*xb[j];
             sumq2 += w*q*q;
         }
-        if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
-            d = sumqx/sumq2; best = d * sumqx;
-            best_id = id;
+        d = sumqx/sumq2;
+        float best = d*sumqx;
+        for (int itry = -ntry; itry <= ntry; ++itry) {
+            id = (itry + values[0])/max;
+            sumqx = sumq2 = 0;
+            for (int j = 0; j < block_size; ++j) {
+                float al = id*xb[j];
+                int l = best_index_int8(16, values, al);
+                float q = values[l];
+                float w = weight[j];
+                sumqx += w*q*xb[j];
+                sumq2 += w*q*q;
+            }
+            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                d = sumqx/sumq2; best = d * sumqx;
+            }
+        }
+        scales[ib] = d;
+        float abs_d = fabsf(d);
+        if (abs_d > amax_scale) {
+            amax_scale = abs_d; max_scale = d;
         }
     }
-    dh[0] = GGML_FP32_TO_FP16(d);
-    for (int j = 0; j < block_size; ++j) {
-        L[j] = best_index_int8(16, values, best_id*xb[j]);
+
+    if (super_block_size/block_size > 1) {
+        int nb = super_block_size/block_size;
+        memset(scales_h, 0, ((nb+3)/4)*sizeof(uint16_t));
+        float d = -max_scale/32;
+        dh[0] = GGML_FP32_TO_FP16(d);
+        float id = d ? 1/d : 0.f;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            int l = nearest_int(id*scales[ib]);
+            l = MAX(-32, MIN(31, l));
+            float dl = d * l;
+            float idl = dl ? 1/dl : 0.f;
+            uint8_t * Lb = L + ib*block_size;
+            const float * xb = x + ib*block_size;
+            for (int j = 0; j < block_size; ++j) {
+                Lb[j] = best_index_int8(16, values, idl*xb[j]);
+            }
+            l += 32;
+            uint8_t l_l = l & 0xf;
+            uint8_t l_h = l >> 4;
+            if (ib%2 == 0) scales_l[ib/2] = l_l;
+            else scales_l[ib/2] |= (l_l << 4);
+            scales_h[ib/8] |= (l_h << 2*(ib%8));
+        }
+    } else {
+        dh[0] = GGML_FP32_TO_FP16(scales[0]);
+        float id = scales[0] ? 1/scales[0] : 0;
+        for (int j = 0; j < super_block_size; ++j) {
+            L[j] = best_index_int8(16, values, id*x[j]);
+        }
     }
-    for (int i = 0; i < block_size/32; ++i) {
+
+    for (int i = 0; i < super_block_size/32; ++i) {
         for (int j = 0; j < 16; ++j) {
             q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
         }
@@ -12218,11 +12268,15 @@ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, i
     char * qrow = (char *)dst;
     uint8_t L[QK4_NL];
     float weight[QK4_NL];
+    uint16_t unused_h;
+    uint8_t * unused_l = NULL;
+    float scale;
     for (int row = 0; row < nrow; ++row) {
         block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
             const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
-            quantize_row_iq4_nl_impl(QK4_NL, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                    &scale, weight, L, kvalues_iq4nl, qw);
         }
         src += n_per_row;
         qrow += nblock*sizeof(block_iq4_nl);
@@ -12243,16 +12297,18 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
 
 size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    GGML_ASSERT(n_per_row%QK4_XS == 0);
-    int nblock = n_per_row/QK4_XS;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    uint8_t L[QK4_XS];
-    float weight[QK4_XS];
+    uint8_t L[QK_K];
+    float weight[32];
+    float scales[QK_K/32];
     for (int row = 0; row < nrow; ++row) {
         block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
-            const float * qw = quant_weights ? quant_weights + QK4_XS*ibl : NULL;
-            quantize_row_iq4_nl_impl(QK4_XS, src + QK4_XS*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
+                    scales, weight, L, kvalues_iq4nl, qw);
         }
         src += n_per_row;
         qrow += nblock*sizeof(block_iq4_xs);
diff --git a/ggml-quants.h b/ggml-quants.h
index 37e245ef3..57bed25be 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -230,13 +230,13 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
-#define QK4_XS 64
 typedef struct {
     ggml_fp16_t d;
-    uint8_t qs[QK4_XS/2];
+    uint16_t scales_h;
+    uint8_t scales_l[QK_K/64];
+    uint8_t qs[QK_K/2];
 } block_iq4_xs;
-static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + QK4_XS/2, "wrong iq4_xs block size/padding");
-static_assert(QK4_XS%32 == 0, "QK4_XS must be a multiple of 32");
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/ggml.c b/ggml.c
index ee79039dc..5b4bf8b28 100644
--- a/ggml.c
+++ b/ggml.c
@@ -728,7 +728,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name                = "iq4_xs",
-        .blck_size                = QK4_XS,
+        .blck_size                = QK_K,
         .type_size                = sizeof(block_iq4_xs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
diff --git a/llama.cpp b/llama.cpp
index aac12f72e..00fd3ffb4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10903,8 +10903,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
             if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
         }
     }
-    else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
-        if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
+    else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS)) {
+        if (!qs.has_imatrix) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q4_K;
     }
     else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
     else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
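
For reference, here is a minimal sketch (not part of the patch) of how the new IQ4_XS per-sub-block scales are unpacked: each QK_K-element super-block carries one fp16 scale d plus a 6-bit scale per 32-element sub-block, with the low 4 bits stored as nibbles in scales_l and the high 2 bits packed into scales_h. The helper name iq4_xs_sub_block_scale is invented for illustration; the bit manipulation mirrors dequantize_row_iq4_xs above and assumes QK_K == 256 (eight sub-blocks per super-block).

#include <stdint.h>

// Hypothetical helper (illustration only): recover the signed 6-bit scale of
// sub-block ib (0..7) from the packed block_iq4_xs fields, the same way the
// patched dequantize_row_iq4_xs does.
static inline int iq4_xs_sub_block_scale(uint16_t scales_h, const uint8_t * scales_l, int ib) {
    const int ls_lo = (scales_l[ib/2] >> 4*(ib%2)) & 0xf;  // 4 low bits, two sub-blocks per byte
    const int ls_hi = (scales_h >> 2*ib) & 3;              // 2 high bits from the packed uint16_t
    return ((ls_hi << 4) | ls_lo) - 32;                    // signed scale in [-32, 31]
}

// The effective multiplier for the 32 values of that sub-block is then
// GGML_FP16_TO_FP32(x[i].d) * iq4_xs_sub_block_scale(x[i].scales_h, x[i].scales_l, ib),
// matching the dl computed in dequantize_row_iq4_xs.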