k_quants: switch Q4_K to 4-bit scales when QK_K = 64
Here the loss in accuracy is greater than it was for Q3_K, but the Q4_K points still move further to the left on the perplexity-vs-size curve, so the smaller blocks remain a net win.
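For QK_K = 64 this shrinks each super-block from 40 to 38 bytes: the old layout stored four fp16 values (a scale and a min per 32-weight block, 8 bytes) next to 32 bytes of quants, i.e. 5.0 bits per weight, while the new one stores one fp16 super-block scale and min (4 bytes) plus two bytes of packed 4-bit sub-scales/sub-mins, i.e. 4.75 bits per weight (both sizes follow from the static_asserts in the diff). A minimal scalar sketch of the new dequantization; the function name is ours, but the body mirrors the updated dequantize_row_q4_K below:

    // Sketch only: scalar dequantization of one QK_K == 64 super-block.
    // d[0]/d[1] hold the fp16 super-block scale/min; scales[j] packs the
    // 4-bit sub-scale (low nibble) and sub-min (high nibble) of block j.
    static void dequantize_q4_K_qkk64(const block_q4_K * x, float * y) {
        const float dall = ggml_fp16_to_fp32(x->d[0]);
        const float mall = ggml_fp16_to_fp32(x->d[1]);
        const float d1 = dall * (x->scales[0] & 0xF), m1 = mall * (x->scales[0] >> 4);
        const float d2 = dall * (x->scales[1] & 0xF), m2 = mall * (x->scales[1] >> 4);
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d1 * (x->qs[l] & 0xF) - m1; // low nibbles: first 32 weights
            y[l + 32] = d2 * (x->qs[l] >>  4) - m2; // high nibbles: second 32 weights
        }
    }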
commit ce19b965f0
parent aeefd4e781
4 changed files with 98 additions and 54 deletions
ggml-cuda.cu (30 changes)
@@ -147,10 +147,11 @@ typedef struct {

 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2*QK_K/32]; // super-block scales/mins
+    half    d[2];         // super-block scales/mins
+    uint8_t scales[2];    // 4-bit block scales/mins
     uint8_t qs[QK_K/2];   // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half d;             // super-block scale for quantized scales
@@ -503,8 +504,10 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
-    y[tid+32] = (float)x[i].d[2] * (q[tid] >>  4) - (float)x[i].d[3];
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
 #endif
 }

@@ -874,20 +877,25 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 #else
     const int step = tid * K_QUANTS_PER_ITERATION;

+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
     float tmp = 0;

     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
         const uint8_t * q = x[i].qs + step;
         const float * y = yy + i*QK_K + step;
-        const half2 * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
-                 + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
-                 + y[j+32] * (df2.x * (q[j+ 0] >>  4) - df2.y)
-                 + y[j+48] * (df2.x * (q[j+16] >>  4) - df2.y);
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
         }
         tmp += sum;
     }
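The aux16 trick above, which reappears in the CPU paths below, unpacks both scale bytes in two operations: the two bytes of x[i].scales are read as one uint16_t (ggml targets little-endian here), so a & 0x0f0f extracts both 4-bit sub-scales at once (s[0], s[1]) and (a >> 4) & 0x0f0f both sub-mins (s[2], s[3]). A standalone demo of the idea (ours, not from the commit):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        const uint8_t scales[2] = { 0x5A, 0x3C }; // low nibble = scale, high nibble = min
        uint16_t a, aux16[2];
        memcpy(&a, scales, 2);                    // a = 0x3C5A on little-endian machines
        aux16[0] = a & 0x0f0f;                    // sub-scales land in bytes 0 and 1
        aux16[1] = (a >> 4) & 0x0f0f;             // sub-mins   land in bytes 2 and 3
        const uint8_t * s = (const uint8_t *)aux16;
        printf("%u %u %u %u\n", s[0], s[1], s[2], s[3]); // prints: 10 12 5 3
        return 0;
    }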
k_quants.c (112 changes)
@@ -635,14 +635,11 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     const int nb = k / QK_K;

     uint8_t L[QK_K];
-#if QK_K == 256
     float mins[QK_K/32];
     float scales[QK_K/32];
-#endif

     for (int i = 0; i < nb; i++) {

-#if QK_K == 256
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -657,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }

+#if QK_K == 256
         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -689,23 +687,37 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
-        }
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);

+        float sumlx = 0;
+        int   suml2 = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
             for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
+                int l = nearest_int((x[32*j + ii] + m)/d);
                 l = MAX(0, MIN(15, l));
                 L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
             }
         }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
 #endif
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
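The sumlx/suml2 accumulation in the new branch is a least-squares refinement: once the integer levels l and the 4-bit sub-scales sd are fixed, the super-block scale minimizing sum((x + m) - d*sd*l)^2 is d = sum((x + m)*sd*l) / sum((sd*l)^2), which is exactly what y[i].d[0] = sumlx/suml2 stores. This recovers some of the precision given up by squeezing the per-block scales into 4 bits.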
@@ -743,8 +755,10 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
         q += 32; is += 2;
     }
 #else
-    float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-    float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+    const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+    const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+    const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+    const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
     for (int l = 0; l < 32; ++l) {
         y[l+ 0] = d1 * (q[l] & 0xF) - m1;
         y[l+32] = d2 * (q[l] >>  4) - m2;
@@ -1953,7 +1967,6 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     const __m256i m3 = _mm256_set1_epi8(3);
     const __m256i m1 = _mm256_set1_epi8(1);
-    const __m256i m8 = _mm256_set1_epi16(8);

     __m256 acc = _mm256_setzero_ps();

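(The hunk above touches ggml_vec_dot_q3_K_q8_K only to drop the m8 constant, which is not referenced anywhere else; apparently a drive-by cleanup of dead code bundled with the commit.)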
@@ -2182,7 +2195,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     for (int i = 0; i < nb; ++i) {

-#if QK_K == 256
         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
         const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

@@ -2192,10 +2204,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
         utmp[2] = uaux;
         utmp[0] &= kmask1;
-#else
-        // TODO
-        const float d = 0; const float dmin = 0;
-#endif

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -2326,15 +2334,22 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     float sum_mins = 0.f;

+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

-        const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d));
-        float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1])
-                   + vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]);
-        sum_mins += y[i].d * summ;
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];

         const uint8x16x2_t q4bits = vld1q_u8_x2(q4);

@@ -2344,13 +2359,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));

         const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-        const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];

         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

         const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
-        const float sumf2 = vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

 #else
         q8bytes = vld1q_s8_x4(q8);
@@ -2360,7 +2375,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0);
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];

         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@@ -2368,10 +2383,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2);
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];

 #endif
-        sumf += y[i].d * (sumf1 + sumf2);
+        sumf += d * (sumi1 + sumi2);

     }

@@ -2380,16 +2395,25 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #elif defined __AVX2__

     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);

     __m256 acc = _mm256_setzero_ps();

     float summs = 0;

+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {

-        summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -2404,11 +2428,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
         const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);

-        const __m256i p32l = _mm256_madd_epi16(m1, p16l);
-        const __m256i p32h = _mm256_madd_epi16(m1, p16h);
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);

-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc);
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);

     }

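Note how the AVX2 rewrite gets the per-block scales for free: the old code multiplied the 16-bit products by a vector of ones (m1) purely to widen them to 32 bits, so broadcasting scales[j] into that same _mm256_madd_epi16 applies the 4-bit sub-scale without any additional instructions, and a single vd factor then carries the combined fp16 super-block scale into the fmadd.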
@@ -2421,6 +2445,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     float sums [8];
     memset(sums, 0, 8*sizeof(float));

+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const uint8_t * restrict q4 = x[i].qs;
@@ -2429,16 +2456,21 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
         for (int l = 0; l < 32; ++l) a[l+32] = q4[l]  >> 4;

-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);

         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
             for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
             q8 += 16; a += 16;
             for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
             q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]);
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
k_quants.h (7 changes)

@@ -59,10 +59,11 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 +
 // Effectively 4.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
+    ggml_fp16_t d[2];         // super-block scales/mins
+    uint8_t scales[2];        // 4-bit block scales/mins
     uint8_t qs[QK_K/2];       // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;             // super-block scale for quantized scales
@@ -70,8 +71,8 @@ typedef struct {
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];           // 4--bit quants
 } block_q4_K;
-#endif
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif

 // 5-bit quantization
 // 16 blocks of 32 elements each
llama.cpp (3 changes)

@@ -2533,12 +2533,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                     use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+            else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+                    (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
             ++i_attention_wv;
         } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
             if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
             else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                     use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+            //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
             ++i_feed_forward_w2;
         } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
             if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
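The new llama.cpp rule promotes attention.wv tensors in the first and last eighth of the layers to Q6_K when QK_K == 64 and one of the S ftypes is used: with n_attention_wv = 32, for example, layers 0-3 and 28-31 get the extra bits, presumably to offset the coarser 4-bit scales where the model is most sensitive.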