k_quants: switch Q4_K to 4-bit scales when QK_K = 64

Here the loss in accuracy is greater than for Q3_K, but the Q4_K points still move further to the left on the perplexity vs size curve.
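With QK_K = 64 the block now stores two fp16 super-block values (overall scale and min) plus one packed byte of 4-bit scale/min per 32-element sub-block, instead of four fp16 values, shrinking the block from 40 to 38 bytes per 64 weights (5.0 to 4.75 bits per weight). A minimal scalar sketch of the resulting dequantization, following the dequantize_row_q4_K change below; the struct is a simplified stand-in (plain float instead of ggml_fp16_t, hypothetical names), not the real ggml type:

    #include <stdint.h>

    #define QK_K 64

    /* Simplified stand-in for the QK_K == 64 block_q4_K: d[0]/d[1] are the
     * super-block scale/min, scales[j] packs the 4-bit scale (low nibble)
     * and min (high nibble) of sub-block j, qs holds two 4-bit quants per byte. */
    typedef struct {
        float   d[2];
        uint8_t scales[2];
        uint8_t qs[QK_K/2];
    } block_q4_K_64;

    /* Mirrors the new dequantization path: y = d*sc*q - m*mn per 32-element sub-block. */
    static void dequantize_block_q4_K_64(const block_q4_K_64 * x, float * y) {
        const float dall = x->d[0];
        const float mall = x->d[1];
        const float d1 = dall * (x->scales[0] & 0xF), m1 = mall * (x->scales[0] >> 4);
        const float d2 = dall * (x->scales[1] & 0xF), m2 = mall * (x->scales[1] >> 4);
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d1 * (x->qs[l] & 0xF) - m1;
            y[l + 32] = d2 * (x->qs[l] >>  4) - m2;
        }
    }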
Iwan Kawrakow 2023-06-24 15:44:23 +03:00
parent aeefd4e781
commit ce19b965f0
4 changed files with 98 additions and 54 deletions

ggml-cuda.cu

@@ -147,10 +147,11 @@ typedef struct {
 #ifdef GGML_QKK_64
 typedef struct {
-    half d[2*QK_K/32];         // super-block scales/mins
+    half d[2];                 // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half d;                    // super-block scale for quantized scales
@@ -503,8 +504,10 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
-    y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3];
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
 #endif
 }
@@ -874,20 +877,25 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 #else
     const int step = tid * K_QUANTS_PER_ITERATION;

+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
     float tmp = 0;

     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
         const uint8_t * q = x[i].qs + step;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
-                 + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
-                 + y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y)
-                 + y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y);
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
         }
         tmp += sum;
     }
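Both the CUDA kernel above and the CPU paths further down unpack the two packed scale/min bytes with a single 16-bit load and two masks: the low nibbles of both bytes land in aux16[0] and the high nibbles in aux16[1], so, viewed as bytes, s[0]/s[1] are the two sub-block scales and s[2]/s[3] the two mins. A small stand-alone sketch of the same trick (assumes little-endian byte order, as the kernels implicitly do; example values are hypothetical):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        /* Two packed scale/min bytes as stored in the QK_K == 64 block_q4_K:
         * low nibble = 4-bit scale, high nibble = 4-bit min, one byte per sub-block. */
        const uint8_t packed[2] = { 0x7A, 0x3C };  /* example: scales 10,12  mins 7,3 */

        uint16_t a;
        memcpy(&a, packed, sizeof(a));             /* both bytes in one 16-bit load */

        uint16_t aux16[2];
        const uint8_t * s = (const uint8_t *)aux16;
        aux16[0] = a & 0x0f0f;                     /* low nibbles  -> s[0], s[1] (scales) */
        aux16[1] = (a >> 4) & 0x0f0f;              /* high nibbles -> s[2], s[3] (mins)   */

        printf("scales: %d %d  mins: %d %d\n", s[0], s[1], s[2], s[3]);
        return 0;
    }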

k_quants.c

@@ -635,14 +635,11 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
    const int nb = k / QK_K;

    uint8_t L[QK_K];
-#if QK_K == 256
    float mins[QK_K/32];
    float scales[QK_K/32];
-#endif

    for (int i = 0; i < nb; i++) {

-#if QK_K == 256
        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/32; ++j) {
@@ -657,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
            }
        }

+#if QK_K == 256
        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
        for (int j = 0; j < QK_K/32; ++j) {
@@ -689,23 +687,37 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
            }
        }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
-        }
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
            if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
+                int l = nearest_int((x[32*j + ii] + m)/d);
                l = MAX(0, MIN(15, l));
                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
            }
        }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
 #endif

        uint8_t * q = y[i].qs;
        for (int j = 0; j < QK_K; j += 64) {
@@ -743,8 +755,10 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
            q += 32; is += 2;
        }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
        for (int l = 0; l < 32; ++l) {
            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
            y[l+32] = d2 * (q[l] >> 4) - m2;
@@ -1953,7 +1967,6 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
    const __m256i m3 = _mm256_set1_epi8(3);
    const __m256i m1 = _mm256_set1_epi8(1);
-    const __m256i m8 = _mm256_set1_epi16(8);

    __m256 acc = _mm256_setzero_ps();
@@ -2182,7 +2195,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
    for (int i = 0; i < nb; ++i) {

-#if QK_K == 256
        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
@@ -2192,10 +2204,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;
-#else
-        // TODO
-        const float d = 0; const float dmin = 0;
-#endif

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;
@@ -2326,15 +2334,22 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
    float sum_mins = 0.f;

+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
    for (int i = 0; i < nb; ++i) {

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

-        const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d));
-        float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1])
-                   + vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]);
-        sum_mins += y[i].d * summ;
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];

        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
@@ -2344,13 +2359,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-        const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];

        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
-        const float sumf2 = vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

 #else

        q8bytes = vld1q_s8_x4(q8);
@ -2360,7 +2375,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0); int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@@ -2368,10 +2383,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2);
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];

 #endif
-        sumf += y[i].d * (sumf1 + sumf2);
+        sumf += d * (sumi1 + sumi2);
    }
@@ -2380,16 +2395,25 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #elif defined __AVX2__

    const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);

    __m256 acc = _mm256_setzero_ps();

    float summs = 0;

+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
    for (int i = 0; i < nb; ++i) {

-        summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;
@@ -2404,11 +2428,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
        const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
        const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);

-        const __m256i p32l = _mm256_madd_epi16(m1, p16l);
-        const __m256i p32h = _mm256_madd_epi16(m1, p16h);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc);
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);

    }
@@ -2421,6 +2445,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
    float sums [8];
    memset(sums, 0, 8*sizeof(float));

+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q4 = x[i].qs;
@@ -2429,16 +2456,21 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
        for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;

-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);

        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
            q8 += 16; a += 16;
            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
            q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]);
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
        }
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
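After fixing the 4-bit sub-block scales and mins, the QK_K = 64 quantization path refits the fp16 super-block scale by least squares over the chosen integer levels: minimizing sum_i (d*sd*l_i - (x_i + m))^2 in d gives d = sum((x_i + m)*l_i*sd) / sum((l_i*sd)^2), which is exactly the sumlx/suml2 update at the end of the #else branch in quantize_row_q4_K_reference above. A sketch of that refit in isolation (hypothetical helper, plain floats instead of fp16):

    #include <stdint.h>

    /* Refit the super-block scale after the integer levels L and the 4-bit
     * sub-block scales sd have been fixed: least-squares solution of
     *   minimize_d  sum_i (d*sd*l_i - (x_i + m))^2
     * mirroring the sumlx/suml2 accumulation in quantize_row_q4_K_reference. */
    static float refit_superblock_scale(const float * x, const float * m,
                                        const uint8_t * L, const uint8_t * sd,
                                        int n_sub, int sub_size) {
        float sumlx = 0;
        int   suml2 = 0;
        for (int j = 0; j < n_sub; ++j) {
            for (int i = 0; i < sub_size; ++i) {
                const int l = L[j*sub_size + i];
                sumlx += (x[j*sub_size + i] + m[j]) * l * sd[j];
                suml2 += l * l * sd[j] * sd[j];
            }
        }
        return suml2 ? sumlx / suml2 : 0.0f;
    }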

k_quants.h

@@ -59,10 +59,11 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 +
 // Effectively 4.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32];  // super-block scales/mins
+    ggml_fp16_t d[2];          // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;             // super-block scale for quantized scales
@@ -70,8 +71,8 @@ typedef struct {
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];           // 4--bit quants
 } block_q4_K;
-#endif
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif

 // 5-bit quantization
 // 16 blocks of 32 elements each

llama.cpp

@@ -2533,12 +2533,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                        use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+                else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+                        (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
                ++i_attention_wv;
            } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                        use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+                //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
                ++i_feed_forward_w2;
            } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;