diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 22f13d306..840c0a863 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -147,10 +147,11 @@ typedef struct {
 
 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2*QK_K/32];      // super-block scales/mins
+    half    d[2];              // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half d;                    // super-block scale for quantized scales
@@ -503,8 +504,10 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
-    y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3];
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
 #endif
 }
 
@@ -874,20 +877,25 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 #else
     const int step = tid * K_QUANTS_PER_ITERATION;
 
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
     float tmp = 0;
 
     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
         const uint8_t * q = x[i].qs + step;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
-                 + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
-                 + y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y)
-                 + y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y);
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
         }
         tmp += sum;
     }
diff --git a/k_quants.c b/k_quants.c
index c2c96d4a5..3d395e59f 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -635,14 +635,11 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     const int nb = k / QK_K;
 
     uint8_t L[QK_K];
-#if QK_K == 256
     float mins[QK_K/32];
     float scales[QK_K/32];
-#endif
 
     for (int i = 0; i < nb; i++) {
 
-#if QK_K == 256
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -657,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 
+#if QK_K == 256
         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -689,23 +687,37 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
-        }
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
 
+        float sumlx = 0;
+        int   suml2 = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >> 4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
             for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
+                int l = nearest_int((x[32*j + ii] + m)/d);
                 l = MAX(0, MIN(15, l));
                 L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
             }
         }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
 #endif
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
@@ -743,8 +755,10 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
             q += 32; is += 2;
         }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
         for (int l = 0; l < 32; ++l) {
             y[l+ 0] = d1 * (q[l] & 0xF) - m1;
             y[l+32] = d2 * (q[l] >> 4) - m2;
@@ -1953,7 +1967,6 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     const __m256i m3 = _mm256_set1_epi8(3);
     const __m256i m1 = _mm256_set1_epi8(1);
-    const __m256i m8 = _mm256_set1_epi16(8);
 
     __m256 acc = _mm256_setzero_ps();
 
@@ -2182,7 +2195,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     for (int i = 0; i < nb; ++i) {
 
-#if QK_K == 256
         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
         const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
 
@@ -2192,10 +2204,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
         utmp[2] = uaux;
         utmp[0] &= kmask1;
-#else
-        // TODO
-        const float d = 0; const float dmin = 0;
-#endif
 
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -2326,15 +2334,22 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     float sum_mins = 0.f;
 
+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d));
-        float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1])
-                   + vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]);
-        sum_mins += y[i].d * summ;
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];
 
         const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
 
@@ -2344,13 +2359,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
         const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-        const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
 
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
         const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
-        const float sumf2 = vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
 #else
         q8bytes = vld1q_s8_x4(q8);
@@ -2360,7 +2375,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0);
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
 
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@@ -2368,10 +2383,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2);
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
 
 #endif
 
-        sumf += y[i].d * (sumf1 + sumf2);
+        sumf += d * (sumi1 + sumi2);
 
     }
@@ -2380,16 +2395,25 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);
 
     __m256 acc = _mm256_setzero_ps();
 
     float summs = 0;
 
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {
 
-        summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
 
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -2404,11 +2428,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
         const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
 
-        const __m256i p32l = _mm256_madd_epi16(m1, p16l);
-        const __m256i p32h = _mm256_madd_epi16(m1, p16h);
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
 
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc);
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
 
     }
 
@@ -2421,6 +2445,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     float sums [8];
    memset(sums, 0, 8*sizeof(float));
 
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const uint8_t * restrict q4 = x[i].qs;
@@ -2429,16 +2456,21 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
         for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
 
-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
 
         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
             for (int l = 0; l < 16; ++l) aux16[l]  = q8[l] * a[l];
             q8 += 16; a += 16;
             for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
             q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]);
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
diff --git a/k_quants.h b/k_quants.h
index 943b62ee5..6256ae167 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -59,10 +59,11 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 +
 // Effectively 4.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32];  // super-block scales/mins
+    ggml_fp16_t d[2];          // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;             // super-block scale for quantized scales
@@ -70,8 +71,8 @@ typedef struct {
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-#endif
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
 // 5-bit quantization
 // 16 blocks of 32 elements each
diff --git a/llama.cpp b/llama.cpp
index 78eb8427c..c41c2a8a3 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2533,12 +2533,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                     use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+            else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+                    (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
             ++i_attention_wv;
         } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
             if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
             else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                    use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+            //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
             ++i_feed_forward_w2;
         } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
             if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
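
Note on the QK_K == 64 layout changed above: the patched block_q4_K keeps one fp16 super-block scale (d[0]) and one fp16 super-block min (d[1]) and packs the two per-32-weight scale/min pairs into scales[0..1] as 4-bit nibbles, so a 64-weight super-block now takes 2*sizeof(ggml_fp16_t) + 2 + QK_K/2 = 38 bytes (4.75 bits per weight) instead of the old four fp16 values plus quants (40 bytes, 5 bpw). The following is a minimal standalone C sketch of that layout and of the scalar dequantization, mirroring the #else branch of dequantize_row_q4_K in the patch. The names fp16_t, fp16_to_fp32, block_q4_K_64 and dequantize_block_q4_K_64 are illustrative stand-ins rather than the identifiers used in k_quants.h, and the half-to-float helper is simplified (normals and zero only).

#include <stdint.h>
#include <string.h>

#define QK_K 64                /* super-block size assumed under GGML_QKK_64 */

typedef uint16_t fp16_t;       /* illustrative stand-in for ggml_fp16_t */

/* simplified IEEE half -> float conversion (normals and zero only);
 * the real code uses ggml_fp16_to_fp32 */
static float fp16_to_fp32(fp16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t man  = h & 0x3FF;
    uint32_t bits = sign;
    if (exp != 0) bits |= ((exp + 112) << 23) | (man << 13); /* 112 = 127 - 15 */
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* QK_K == 64 layout from the patch: 2 fp16 + 2 packed scale/min bytes + 32 quant bytes = 38 bytes */
typedef struct {
    fp16_t  d[2];          /* d[0] = super-block scale, d[1] = super-block min       */
    uint8_t scales[2];     /* per-32-weight 4-bit scale (low nibble) / min (high)    */
    uint8_t qs[QK_K/2];    /* 4-bit quants, two per byte                             */
} block_q4_K_64;

/* scalar dequantization of one super-block, mirroring the #else branch of
 * dequantize_row_q4_K: block 0 lives in the low nibbles of qs, block 1 in the high nibbles */
static void dequantize_block_q4_K_64(const block_q4_K_64 * x, float * y) {
    const float dall = fp16_to_fp32(x->d[0]);
    const float mall = fp16_to_fp32(x->d[1]);
    const float d1 = dall * (x->scales[0] & 0xF), m1 = mall * (x->scales[0] >> 4);
    const float d2 = dall * (x->scales[1] & 0xF), m2 = mall * (x->scales[1] >> 4);
    for (int l = 0; l < 32; ++l) {
        y[l +  0] = d1 * (x->qs[l] & 0xF) - m1;
        y[l + 32] = d2 * (x->qs[l] >>  4) - m2;
    }
}

Quantization in the patch goes the other way (new QK_K == 64 branch of quantize_row_q4_K_reference): the per-block float scales/mins are mapped to 4-bit indices against the super-block maxima using s_factor = 15, and d[0] is then refined with the least-squares ratio sumlx/suml2.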
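
The vector paths in the patch (the CUDA dequantize_mul_mat_vec_q4_k kernel, the NEON and AVX2 dot products, and the scalar fallback) all unpack the two packed scale/min bytes in one step by reading x[i].scales through a uint16_t pointer and masking with 0x0f0f. Below is a small self-contained check of that trick, assuming a little-endian target as the kernels themselves do; the byte values are made up for illustration.

#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void) {
    /* two packed scale/min bytes as stored in block_q4_K.scales:
     * low nibble = 4-bit scale, high nibble = 4-bit min (example values) */
    const uint8_t packed[2] = { 0x4B, 0x92 };

    uint16_t a;                       /* the kernels view x[i].scales as one uint16_t */
    memcpy(&a, packed, sizeof a);

    uint16_t aux16[2];
    aux16[0] = a & 0x0f0f;            /* both 4-bit scales, one per byte */
    aux16[1] = (a >> 4) & 0x0f0f;     /* both 4-bit mins,   one per byte */

    /* viewed as bytes: s[0], s[1] are the scales, s[2], s[3] are the mins */
    const uint8_t * s = (const uint8_t *)aux16;
    assert(s[0] == (packed[0] & 0xF) && s[1] == (packed[1] & 0xF));
    assert(s[2] == (packed[0] >>  4) && s[3] == (packed[1] >>  4));
    return 0;
}

This is what lets the kernels index s[0], s[1] for the two block scales and s[2], s[3] for the two block mins without any per-byte shifting.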