k_quants: switch Q3_K to 4-bit scales when QK_K = 64
Otherwise there isn't much benefit from this quantization type. There is some very slight loss in accuracy, but we reduce size by ~7%. E.g., for OpenLLaMA-3B, Q3_K_S perplexity is 8.6131 with 8-bit scales and 8.6352 with 4-bit, while file size decreases from 1.53G to 1.44G.
This commit is contained in:
parent
88412a1aa0
commit
aeefd4e781
4 changed files with 86 additions and 50 deletions
23
ggml-cuda.cu
23
ggml-cuda.cu
|
@ -137,13 +137,13 @@ typedef struct {
|
||||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
int8_t scales[K_SCALE_SIZE]; // scales, quantized with 8 bits
|
uint8_t scales[2]; // scales, quantized with 8 bits
|
||||||
#else
|
#else
|
||||||
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
||||||
#endif
|
#endif
|
||||||
half d; // super-block scale
|
half d; // super-block scale
|
||||||
} block_q3_K;
|
} block_q3_K;
|
||||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
||||||
|
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -448,8 +448,13 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
||||||
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
||||||
const float d = (float)x[i].d;
|
const float d = (float)x[i].d;
|
||||||
|
|
||||||
y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
if (is == 0) {
|
||||||
y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
||||||
|
y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
||||||
|
} else {
|
||||||
|
y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
||||||
|
y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -776,7 +781,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
|
|
||||||
const float * y = yy + i * QK_K + offset;
|
const float * y = yy + i * QK_K + offset;
|
||||||
const uint8_t * q = x[i].qs + offset;
|
const uint8_t * q = x[i].qs + offset;
|
||||||
const int8_t * s = x[i].scales;
|
const uint8_t * s = x[i].scales;
|
||||||
|
|
||||||
const float dall = (float)x[i].d;
|
const float dall = (float)x[i].d;
|
||||||
|
|
||||||
|
@ -784,10 +789,10 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
||||||
const uint8_t hl = x[i].hmask[im+l] >> in;
|
const uint8_t hl = x[i].hmask[im+l] >> in;
|
||||||
const uint8_t ql = q[l];
|
const uint8_t ql = q[l];
|
||||||
sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
|
sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
|
||||||
+ y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
|
+ y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
|
||||||
+ y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
|
+ y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
|
||||||
+ y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
|
+ y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
|
||||||
}
|
}
|
||||||
tmp += sum;
|
tmp += sum;
|
||||||
}
|
}
|
||||||
|
|
|
@ -799,7 +799,7 @@ typedef struct {
|
||||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
int8_t scales[K_SCALE_SIZE];
|
uint8_t scales[2];
|
||||||
#else
|
#else
|
||||||
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
||||||
#endif
|
#endif
|
||||||
|
@ -970,10 +970,10 @@ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, i
|
||||||
device const uint8_t * q = x[i].qs;
|
device const uint8_t * q = x[i].qs;
|
||||||
device const uint8_t * hm = x[i].hmask;
|
device const uint8_t * hm = x[i].hmask;
|
||||||
|
|
||||||
const float d1 = d_all * x[i].scales[0];
|
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
||||||
const float d2 = d_all * x[i].scales[1];
|
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
||||||
const float d3 = d_all * x[i].scales[2];
|
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
||||||
const float d4 = d_all * x[i].scales[3];
|
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
||||||
|
|
||||||
for (int l = 0; l < 8; ++l) {
|
for (int l = 0; l < 8; ++l) {
|
||||||
uint8_t h = hm[l];
|
uint8_t h = hm[l];
|
||||||
|
@ -1417,10 +1417,10 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
device const uint8_t * h = x[i].hmask + in;
|
device const uint8_t * h = x[i].hmask + in;
|
||||||
device const float * y = yy + i * QK_K + il;
|
device const float * y = yy + i * QK_K + il;
|
||||||
|
|
||||||
const float d1 = d_all * x[i].scales[0];
|
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
||||||
const float d2 = d_all * x[i].scales[1];
|
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
||||||
const float d3 = d_all * x[i].scales[2];
|
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
||||||
const float d4 = d_all * x[i].scales[3];
|
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
||||||
|
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const uint8_t hm = h[l] >> im;
|
const uint8_t hm = h[l] >> im;
|
||||||
|
|
77
k_quants.c
77
k_quants.c
|
@ -469,21 +469,24 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (max_scale) {
|
if (max_scale) {
|
||||||
float iscale = -128.f/max_scale;
|
float iscale = -8.f/max_scale;
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; j+=2) {
|
||||||
int l = nearest_int(iscale*scales[j]);
|
int l1 = nearest_int(iscale*scales[j]);
|
||||||
l = MAX(-128, MIN(127, l));
|
l1 = 8 + MAX(-8, MIN(7, l1));
|
||||||
y[i].scales[j] = l;
|
int l2 = nearest_int(iscale*scales[j+1]);
|
||||||
|
l2 = 8 + MAX(-8, MIN(7, l2));
|
||||||
|
y[i].scales[j/2] = l1 | (l2 << 4);
|
||||||
}
|
}
|
||||||
y[i].d = ggml_fp32_to_fp16(1/iscale);
|
y[i].d = ggml_fp32_to_fp16(1/iscale);
|
||||||
} else {
|
} else {
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; j+=2) {
|
||||||
y[i].scales[j] = 0;
|
y[i].scales[j/2] = 0;
|
||||||
}
|
}
|
||||||
y[i].d = ggml_fp32_to_fp16(0.f);
|
y[i].d = ggml_fp32_to_fp16(0.f);
|
||||||
}
|
}
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
|
int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
|
||||||
|
float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
|
||||||
if (!d) {
|
if (!d) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -587,10 +590,10 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
|
||||||
const uint8_t * restrict q = x[i].qs;
|
const uint8_t * restrict q = x[i].qs;
|
||||||
const uint8_t * restrict hm = x[i].hmask;
|
const uint8_t * restrict hm = x[i].hmask;
|
||||||
|
|
||||||
const float d1 = d_all * x[i].scales[0];
|
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
||||||
const float d2 = d_all * x[i].scales[1];
|
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
||||||
const float d3 = d_all * x[i].scales[2];
|
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
||||||
const float d4 = d_all * x[i].scales[3];
|
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
||||||
|
|
||||||
for (int l=0; l<8; ++l) {
|
for (int l=0; l<8; ++l) {
|
||||||
uint8_t h = hm[l];
|
uint8_t h = hm[l];
|
||||||
|
@ -1889,6 +1892,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
int8x16x4_t q3bytes;
|
int8x16x4_t q3bytes;
|
||||||
|
|
||||||
|
uint16_t aux16[2];
|
||||||
|
int8_t * scales = (int8_t *)aux16;
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
@ -1899,8 +1905,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
||||||
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
||||||
|
|
||||||
const int8_t * restrict scales = x[i].scales;
|
const uint16_t a = *(const uint16_t *)x[i].scales;
|
||||||
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[1] * y[i].bsums[1] + scales[2] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
aux16[0] = a & 0x0f0f;
|
||||||
|
aux16[1] = (a >> 4) & 0x0f0f;
|
||||||
|
|
||||||
|
for (int j = 0; j < 4; ++j) scales[j] -= 8;
|
||||||
|
|
||||||
|
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
||||||
|
|
||||||
const float d = y[i].d * (float)x[i].d;
|
const float d = y[i].d * (float)x[i].d;
|
||||||
|
|
||||||
|
@ -1917,8 +1928,8 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_DOTPROD)
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[1];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[2];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
||||||
#else
|
#else
|
||||||
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
@ -1929,7 +1940,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[1] + vaddvq_s16(p2) * scales[2] + vaddvq_s16(p3) * scales[3];
|
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
sum += d * isum;
|
sum += d * isum;
|
||||||
|
@ -1942,11 +1953,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
const __m256i m3 = _mm256_set1_epi8(3);
|
const __m256i m3 = _mm256_set1_epi8(3);
|
||||||
const __m256i m1 = _mm256_set1_epi8(1);
|
const __m256i m1 = _mm256_set1_epi8(1);
|
||||||
|
const __m256i m8 = _mm256_set1_epi16(8);
|
||||||
|
|
||||||
__m256 acc = _mm256_setzero_ps();
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
uint64_t aux64;
|
uint64_t aux64;
|
||||||
|
|
||||||
|
uint16_t aux16[2];
|
||||||
|
const int8_t * aux8 = (const int8_t *)aux16;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||||
|
@ -1954,14 +1969,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
const uint8_t * restrict q3 = x[i].qs;
|
const uint8_t * restrict q3 = x[i].qs;
|
||||||
const int8_t * restrict q8 = y[i].qs;
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
const uint16_t a = *(const uint16_t *)x[i].scales;
|
||||||
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
aux16[0] = a & 0x0f0f;
|
||||||
|
aux16[1] = (a >> 4) & 0x0f0f;
|
||||||
|
|
||||||
|
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
||||||
|
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
||||||
|
|
||||||
// Set up scales
|
|
||||||
memcpy(&aux64, x[i].hmask, 8);
|
memcpy(&aux64, x[i].hmask, 8);
|
||||||
|
|
||||||
__m256i q3h_0 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 3, aux64 >> 2), _mm_set_epi64x(aux64 >> 1, aux64 >> 0));
|
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
||||||
__m256i q3h_1 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 7, aux64 >> 6), _mm_set_epi64x(aux64 >> 5, aux64 >> 4));
|
__m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
|
||||||
|
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
||||||
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
||||||
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
||||||
|
|
||||||
|
@ -1969,8 +1988,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
||||||
|
|
||||||
// prepare low and high bits
|
// prepare low and high bits
|
||||||
const __m256i q3l_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits), m3);
|
const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
|
||||||
const __m256i q3l_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 6), _mm_srli_epi16(q3bits, 4)), m3);
|
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
||||||
|
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
||||||
|
|
||||||
// load Q8 quants
|
// load Q8 quants
|
||||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||||
|
@ -2007,6 +2027,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
int16_t aux16[8];
|
int16_t aux16[8];
|
||||||
float sums [8];
|
float sums [8];
|
||||||
int32_t aux32[8];
|
int32_t aux32[8];
|
||||||
|
int32_t scales[4];
|
||||||
memset(sums, 0, 8*sizeof(float));
|
memset(sums, 0, 8*sizeof(float));
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
@ -2026,14 +2047,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
|
a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scales[0] = (x[i].scales[0] & 0xF) - 8;
|
||||||
|
scales[1] = (x[i].scales[0] >> 4) - 8;
|
||||||
|
scales[2] = (x[i].scales[1] & 0xF) - 8;
|
||||||
|
scales[3] = (x[i].scales[1] >> 4) - 8;
|
||||||
|
|
||||||
memset(aux32, 0, 8*sizeof(int32_t));
|
memset(aux32, 0, 8*sizeof(int32_t));
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||||
q8 += 8; a += 8;
|
q8 += 8; a += 8;
|
||||||
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
||||||
q8 += 8; a += 8;
|
q8 += 8; a += 8;
|
||||||
int sc = x[i].scales[j];
|
for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
|
||||||
for (int l = 0; l < 8; ++l) aux32[l] += sc * aux16[l];
|
|
||||||
}
|
}
|
||||||
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
||||||
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
||||||
|
|
18
k_quants.h
18
k_quants.h
|
@ -35,17 +35,23 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
|
||||||
// weight is represented as x = a * q
|
// weight is represented as x = a * q
|
||||||
// 16 blocks of 16 elemenets each
|
// 16 blocks of 16 elemenets each
|
||||||
// Effectively 3.4375 bits per weight
|
// Effectively 3.4375 bits per weight
|
||||||
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
#ifdef GGML_QKK_64
|
uint8_t scales[2];
|
||||||
int8_t scales[K_SCALE_SIZE];
|
|
||||||
#else
|
|
||||||
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
|
||||||
#endif
|
|
||||||
ggml_fp16_t d; // super-block scale
|
ggml_fp16_t d; // super-block scale
|
||||||
} block_q3_K;
|
} block_q3_K;
|
||||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
|
||||||
|
#else
|
||||||
|
typedef struct {
|
||||||
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
|
uint8_t scales[12]; // scales, quantized with 6 bits
|
||||||
|
ggml_fp16_t d; // super-block scale
|
||||||
|
} block_q3_K;
|
||||||
|
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
|
||||||
|
#endif
|
||||||
|
|
||||||
// 4-bit quantization
|
// 4-bit quantization
|
||||||
// 16 blocks of 32 elements each
|
// 16 blocks of 32 elements each
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue