k_quants: switch Q3_K to 4-bit scales when QK_K = 64

Otherwise there isn't much benefit from this
quantization type. There is a very slight loss
in accuracy, but we reduce size by ~7%.
E.g., for OpenLLaMA-3B, Q3_K_S perplexity is
8.6131 with 8-bit scales and 8.6352 with 4-bit,
while file size decreases from 1.53G to 1.44G.
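(Size check, assuming K_SCALE_SIZE is 4 in the QK_K = 64 build: a block_q3_K
holds 2 bytes of fp16 super-block scale, 16 bytes of 2-bit quants, 8 bytes of
high-bit masks, plus the scales. Four 8-bit scales give 2 + 16 + 8 + 4 = 30
bytes per 64 weights, i.e. 3.75 bpw; two bytes of packed 4-bit scales give
28 bytes, i.e. 3.5 bpw — a ~7% reduction of the quantized tensors.)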
Iwan Kawrakow 2023-06-24 09:57:50 +03:00
parent 88412a1aa0
commit aeefd4e781
4 changed files with 86 additions and 50 deletions
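The scheme, as a minimal standalone sketch (the helper names are illustrative, not part of the commit): each of the four 16-weight sub-block scales is stored as a 4-bit field biased by +8, two fields per byte.

    #include <stdint.h>

    // Pack four signed scales in [-8, 7] into two bytes, low nibble first,
    // biased by +8 so they fit unsigned 4-bit fields.
    static void pack_scales4(const int8_t sc[4], uint8_t out[2]) {
        out[0] = (uint8_t)((sc[0] + 8) | ((sc[1] + 8) << 4));
        out[1] = (uint8_t)((sc[2] + 8) | ((sc[3] + 8) << 4));
    }

    // Unpack; this is the ((s & 0xF) - 8) / ((s >> 4) - 8) pattern that
    // recurs throughout the hunks below.
    static void unpack_scales4(const uint8_t in[2], int8_t sc[4]) {
        sc[0] = (int8_t)(in[0] & 0xF) - 8;
        sc[1] = (int8_t)(in[0] >>  4) - 8;
        sc[2] = (int8_t)(in[1] & 0xF) - 8;
        sc[3] = (int8_t)(in[1] >>  4) - 8;
    }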

ggml-cuda.cu

@@ -137,13 +137,13 @@ typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
 #ifdef GGML_QKK_64
-    int8_t  scales[K_SCALE_SIZE]; // scales, quantized with 8 bits
+    uint8_t scales[2]; // scales, quantized with 8 bits
 #else
     uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
 #endif
     half d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");

 #ifdef GGML_QKK_64
 typedef struct {
@@ -448,8 +448,13 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t h = x[i].hmask[in] >> (2*is + im);
     const float   d = (float)x[i].d;

-    y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
-    y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
 #endif
 }
@@ -776,7 +781,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     const float   * y = yy + i * QK_K + offset;
     const uint8_t * q = x[i].qs + offset;
-    const int8_t  * s = x[i].scales;
+    const uint8_t * s = x[i].scales;

     const float dall = (float)x[i].d;
@@ -784,10 +789,10 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
         const uint8_t hl = x[i].hmask[im+l] >> in;
         const uint8_t ql = q[l];
-        sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
-             + y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
-             + y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
-             + y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+             + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+             + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+             + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
     }
     tmp += sum;
 }

ggml-metal.metal

@@ -799,7 +799,7 @@ typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
 #if QK_K == 64
-    int8_t  scales[K_SCALE_SIZE];
+    uint8_t scales[2];
 #else
     uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
 #endif
@@ -970,10 +970,10 @@ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, i
         device const uint8_t * q = x[i].qs;
         device const uint8_t * hm = x[i].hmask;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l = 0; l < 8; ++l) {
             uint8_t h = hm[l];
@@ -1417,10 +1417,10 @@ kernel void kernel_mul_mat_q3_K_f32(
         device const uint8_t * h = x[i].hmask + in;
         device const float   * y = yy + i * QK_K + il;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l = 0; l < 4; ++l) {
             const uint8_t hm = h[l] >> im;

k_quants.c

@@ -469,21 +469,24 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
         }
 #else
     if (max_scale) {
-        float iscale = -128.f/max_scale;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int l = nearest_int(iscale*scales[j]);
-            l = MAX(-128, MIN(127, l));
-            y[i].scales[j] = l;
+        float iscale = -8.f/max_scale;
+        for (int j = 0; j < QK_K/16; j+=2) {
+            int l1 = nearest_int(iscale*scales[j]);
+            l1 = 8 + MAX(-8, MIN(7, l1));
+            int l2 = nearest_int(iscale*scales[j+1]);
+            l2 = 8 + MAX(-8, MIN(7, l2));
+            y[i].scales[j/2] = l1 | (l2 << 4);
         }
         y[i].d = ggml_fp32_to_fp16(1/iscale);
     } else {
-        for (int j = 0; j < QK_K/16; ++j) {
-            y[i].scales[j] = 0;
+        for (int j = 0; j < QK_K/16; j+=2) {
+            y[i].scales[j/2] = 0;
         }
         y[i].d = ggml_fp32_to_fp16(0.f);
     }
     for (int j = 0; j < QK_K/16; ++j) {
-        float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+        int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+        float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
         if (!d) {
             continue;
         }
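In the new path each pair of sub-block scales is mapped to nibbles: iscale = -8/max_scale sends the largest-magnitude scale to -8, the rounded value is clamped to [-8, 7], biased by +8, and packed. A scalar illustration (lroundf stands in for the library's nearest_int; names are illustrative):

    #include <math.h>
    #include <stdint.h>

    // Quantize one pair of float sub-block scales the way the loop above does.
    static uint8_t pack_scale_pair(float s0, float s1, float iscale) {
        int l1 = (int)lroundf(iscale * s0);
        int l2 = (int)lroundf(iscale * s1);
        l1 = l1 < -8 ? -8 : l1 > 7 ? 7 : l1;           // clamp to 4-bit signed range
        l2 = l2 < -8 ? -8 : l2 > 7 ? 7 : l2;
        return (uint8_t)((l1 + 8) | ((l2 + 8) << 4));  // bias and pack nibbles
    }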
@@ -587,10 +590,10 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
         const uint8_t * restrict q = x[i].qs;
         const uint8_t * restrict hm = x[i].hmask;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l=0; l<8; ++l) {
             uint8_t h = hm[l];
@@ -1889,6 +1892,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     int8x16x4_t q3bytes;

+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
     float sum = 0;

     for (int i = 0; i < nb; ++i) {
@@ -1899,8 +1905,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8x16_t q3bits = vld1q_u8(x[i].qs);
         const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);

-        const int8_t * restrict scales = x[i].scales;
-        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[1] * y[i].bsums[1] + scales[2] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);

         const float d = y[i].d * (float)x[i].d;
@@ -1917,8 +1928,8 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 #if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 #else
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -1929,7 +1940,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[1] + vaddvq_s16(p2) * scales[2] + vaddvq_s16(p3) * scales[3];
+        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
 #endif
         sum += d * isum;
@@ -1942,11 +1953,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const __m256i m3 = _mm256_set1_epi8(3);
     const __m256i m1 = _mm256_set1_epi8(1);
+    const __m256i m8 = _mm256_set1_epi16(8);

     __m256 acc = _mm256_setzero_ps();

     uint64_t aux64;
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;

     for (int i = 0; i < nb; ++i) {

         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
@@ -1954,14 +1969,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict q3 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

-        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
-        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));

         // Set up scales
         memcpy(&aux64, x[i].hmask, 8);

-        __m256i q3h_0 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 3, aux64 >> 2), _mm_set_epi64x(aux64 >> 1, aux64 >> 0));
-        __m256i q3h_1 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 7, aux64 >> 6), _mm_set_epi64x(aux64 >> 5, aux64 >> 4));
+        const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
         q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
         q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
@@ -1969,8 +1988,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
             const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);

             // prepare low and high bits
-            const __m256i q3l_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits), m3);
-            const __m256i q3l_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 6), _mm_srli_epi16(q3bits, 4)), m3);
+            const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+            const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);

             // load Q8 quants
             const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
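The q3aux/haux rewrite is a pure refactor of the same math: one 128-bit value is built once and then shifted with 16-bit-lane intrinsics instead of four scalar 64-bit shifts. Bits can cross byte boundaries inside a 16-bit lane, but the subsequent masks (m1 for the high bits, m3 for the low 2 bits) keep only the bottom one or two bits of each byte, and those always originate in the same source byte, so the results are unchanged.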
@@ -2007,6 +2027,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     int16_t aux16[8];
     float   sums [8];
     int32_t aux32[8];
+    int32_t scales[4];
     memset(sums, 0, 8*sizeof(float));

     float sumf = 0;
@@ -2026,14 +2047,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
             a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
         }

+        scales[0] = (x[i].scales[0] & 0xF) - 8;
+        scales[1] = (x[i].scales[0] >>  4) - 8;
+        scales[2] = (x[i].scales[1] & 0xF) - 8;
+        scales[3] = (x[i].scales[1] >>  4) - 8;
+
         memset(aux32, 0, 8*sizeof(int32_t));
         for (int j = 0; j < QK_K/16; ++j) {
             for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
             q8 += 8; a += 8;
             for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
             q8 += 8; a += 8;
-            int sc = x[i].scales[j];
-            for (int l = 0; l < 8; ++l) aux32[l] += sc * aux16[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
         }
         const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
         for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];

k_quants.h

@@ -35,17 +35,23 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
 // weight is represented as x = a * q
 // 16 blocks of 16 elements each
 // Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
 typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4];    // quants - low 2 bits
-#ifdef GGML_QKK_64
-    int8_t  scales[K_SCALE_SIZE];
-#else
-    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
-#endif
+    uint8_t scales[2];
     ggml_fp16_t d;         // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
+#else
+typedef struct {
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4];    // quants - low 2 bits
+    uint8_t scales[12];    // scales, quantized with 6 bits
+    ggml_fp16_t d;         // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif

 // 4-bit quantization
 // 16 blocks of 32 elements each