From a0b8e9f3c90e482dbe0ca82f45f585de24f1ba67 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 29 May 2023 14:30:17 +0300
Subject: [PATCH] Adding Q4_K - scalar, AVX2, CUDA

Performance is the same or perhaps very slightly better than Q4_0 on the
CPU. On the GPU, single-token prediction is ~10% better than Q4_0; batch
mode (perplexity) is about the same.
---
 examples/quantize/quantize.cpp |   1 +
 ggml-cuda.cu                   | 101 +++++++++++
 ggml.c                         |  37 +++-
 ggml.h                         |   5 +-
 k_quants.c                     | 322 +++++++++++++++++++++++++++++++++
 llama.cpp                      |   4 +
 llama.h                        |   1 +
 7 files changed, 463 insertions(+), 8 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 1f17a2acb..1e106ee10 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -13,6 +13,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
     {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
     {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
     {"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K},
+    {"q4_K", LLAMA_FTYPE_MOSTLY_Q4_K},
 };
 
 bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index a6bfc3ca7..81dffe474 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -97,6 +97,14 @@ typedef struct {
 } block_q3_K;
 static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
 
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+
 #define WARP_SIZE 32
 
 #define CUDA_MUL_BLOCK_SIZE 256
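Each block_q4_K super-block covers QK_K = 256 weights, split into eight sub-blocks of 32 that share one 6-bit scale and one 6-bit min, so a block stores 2 + 2 + 12 + 128 = 144 bytes, i.e. 4.5 bits per weight. A standalone sanity-check sketch of that arithmetic (not part of the patch; it assumes QK_K = 256 and a 2-byte half, with uint16_t standing in for the fp16 type):

    // Sketch only: storage cost of one q4_K super-block.
    #include <stdint.h>
    #include <stdio.h>

    #define QK_K 256

    typedef struct {
        uint16_t d;                  // fp16 super-block scale for the quantized scales
        uint16_t dmin;               // fp16 super-block scale for the quantized mins
        uint8_t  scales[3*QK_K/64];  // 8 scales + 8 mins, 6 bits each -> 12 bytes
        uint8_t  qs[QK_K/2];         // 256 4-bit quants -> 128 bytes
    } block_q4_K;

    int main(void) {
        printf("bytes per block : %zu\n", sizeof(block_q4_K));           // 144
        printf("bits per weight : %.2f\n", 8.0*sizeof(block_q4_K)/QK_K); // 4.50
        return 0;
    }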
@@ -261,6 +269,84 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs
     result = sum * scale;
 }
 
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int i = blockIdx.x;
+
+    //// assume 64 threads - this is very slightly better than the one below
+    //const int tid = threadIdx.x;
+    //const int il  = tid/16;
+    //const int ir  = tid%16;
+    //const int is  = 2*il;
+    //const int n   = 2;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int is  = 2*il;
+    const int n   = 4;
+
+    float * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = x[i].d;
+    const float dmin = x[i].dmin;
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
+    }
+}
+
+static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
+
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    // iqs is in 0...248 in steps of 8 =>
+    const int j  = iqs / 64;          // j  is in 0...3
+    const int ir = (iqs - 64*j)/2;    // ir is in 0...28 in steps of 4
+    const int is = 2*j;               // is is in 0...6 in steps of 2
+
+    const float   * y = yy + 64*j + ir;
+    const uint8_t * q = x[ib].qs + 32*j + ir;
+
+    const float dall = x[ib].d;
+    const float dmin = x[ib].dmin;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[ib].scales, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[ib].scales, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    float sum = 0;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
+        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
+    }
+    result = sum;
+
+}
+
 static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const half * x = (const half *) vx;
 
@@ -405,6 +491,11 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
 }
 
+static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
@@ -451,6 +542,12 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const dim3 block_dims(32, 2, 1);
+    dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -478,6 +575,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q8_0_cuda;
         case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
         case GGML_TYPE_F16:
             return convert_fp16_to_fp32_cuda;
         default:
@@ -499,6 +598,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
             return dequantize_mul_mat_vec_q8_0_cuda;
         case GGML_TYPE_Q3_K:
             return dequantize_mul_mat_vec_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_mul_mat_vec_q4_K_cuda;
         case GGML_TYPE_F16:
             return convert_mul_mat_vec_f16_cuda;
         default:
diff --git a/ggml.c b/ggml.c
index 0e849a621..3cf5ec22d 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1574,6 +1574,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
+    [GGML_TYPE_Q4_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
+        .quantize_row_q           = quantize_row_q4_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
 };
 
 // For internal test use
@@ -3454,12 +3462,13 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
     [GGML_TYPE_Q3_K] = QK_K,
+    [GGML_TYPE_Q4_K] = QK_K,
     [GGML_TYPE_Q8_K] = QK_K,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 15, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 16, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -3470,13 +3479,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
-    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
     [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 15, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3489,12 +3498,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
     [GGML_TYPE_Q3_K] = "q3_K",
+    [GGML_TYPE_Q4_K] = "q4_K",
     [GGML_TYPE_Q8_K] = "q8_K",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 15, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -3505,13 +3515,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
-    [GGML_TYPE_Q3_K] = true,
+    [GGML_TYPE_Q4_K] = true,
     [GGML_TYPE_Q8_K] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 15, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 16, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
@@ -3819,6 +3829,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
         case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7603,6 +7614,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
            {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7907,6 +7919,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8030,6 +8043,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10124,6 +10138,7 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10308,6 +10323,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10474,6 +10490,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -11021,6 +11038,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -11094,6 +11112,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -16104,6 +16123,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q3_K * block = (block_q3_K*)dst + start / QK_K;
                 result = ggml_quantize_q3_K(src + start, block, n, n, hist);
             } break;
+        case GGML_TYPE_Q4_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q4_K * block = (block_q4_K*)dst + start / QK_K;
+                result = ggml_quantize_q4_K(src + start, block, n, n, hist);
+            } break;
         default:
             assert(false);
     }
diff --git a/ggml.h b/ggml.h
index 7cf3b6b96..c38745bc1 100644
--- a/ggml.h
+++ b/ggml.h
@@ -243,10 +243,10 @@ extern "C" {
         GGML_TYPE_Q8_1 = 9,
         // k-quantizations
         GGML_TYPE_Q3_K = 10,
-        //GGML_TYPE_Q4_K = 11,
+        GGML_TYPE_Q4_K = 11,
         //GGML_TYPE_Q5_K = 12,
         //GGML_TYPE_Q6_K = 13,
-        GGML_TYPE_Q8_K = 11,
+        GGML_TYPE_Q8_K = 12,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -271,6 +271,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q5_0 = 8,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_1 = 9,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q3_K = 10, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_K = 11, // except 1d tensors
     };
 
     // available tensor operations:
diff --git a/k_quants.c b/k_quants.c
index 86cc90616..e9f3160c1 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -192,6 +192,57 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
     return 1/iscale;
 }
 
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
 //========================= 3-bit (de)-quantization
 
 void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
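make_qkx1_quants fits one sub-block to x[i] ~ scale*L[i] - the_min, with L[i] in 0..nmax and the min reported as a non-negative value; it starts from the plain min/max mapping and then refines the scale by least squares and the min from the mean residual for up to ntry rounds, stopping early once the assignment stops changing. The zero-iteration starting point alone looks roughly like this (standalone illustration, not the code from the patch):

    // Sketch only: the initial affine fit that make_qkx1_quants() then refines.
    #include <math.h>
    #include <stdint.h>

    static float quantize_subblock_simple(int n, int nmax, const float * x,
                                          uint8_t * L, float * the_min) {
        float min = x[0], max = x[0];
        for (int i = 1; i < n; ++i) {
            if (x[i] < min) min = x[i];
            if (x[i] > max) max = x[i];
        }
        if (min > 0) min = 0;              // the offset may only shift values down
        if (max == min) {                  // flat block: everything quantizes to 0
            for (int i = 0; i < n; ++i) L[i] = 0;
            *the_min = 0;
            return 0.f;
        }
        const float scale = (max - min)/nmax;
        for (int i = 0; i < n; ++i) {
            int l = (int)roundf((x[i] - min)/scale);
            L[i] = (uint8_t)(l < 0 ? 0 : l > nmax ? nmax : l);
        }
        *the_min = -min;                   // reported as a positive "min", as in the patch
        return scale;
    }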
@@ -338,6 +389,116 @@ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n
     return (n/QK_K*sizeof(block_q3_K));
 }
 
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
+        y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = ggml_fp16_to_fp32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
+        }
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d   = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
+            q += 32; is += 2;
+        }
+
+    }
+}
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    (void)hist; // TODO: collect histograms
+    for (int j = 0; j < n; j += k) {
+        block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
+        quantize_row_q4_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q4_K));
+}
+
 // ====================== 6-bit (de)-quantization
 
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
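A quick way to sanity-check the reference pair above is a quantize/dequantize round trip over one 256-weight row. A standalone sketch (not part of the patch; it assumes block_q4_K and the two functions are declared in k_quants.h, which this commit does not modify):

    // Sketch only: round-trip error of the q4_K reference (de)quantization.
    #include <math.h>
    #include <stdio.h>
    #include "k_quants.h"

    int main(void) {
        float x[QK_K], y[QK_K];
        block_q4_K blocks[1];
        for (int i = 0; i < QK_K; ++i) x[i] = sinf(0.1f*i);   // arbitrary test data

        quantize_row_q4_K_reference(x, blocks, QK_K);
        dequantize_row_q4_K(blocks, y, QK_K);

        double err = 0;
        for (int i = 0; i < QK_K; ++i) err += (x[i] - y[i])*(x[i] - y[i]);
        printf("rms error: %g\n", sqrt(err/QK_K));
        return 0;
    }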
@@ -484,6 +645,19 @@ static inline __m256i get_scale_shuffle_q3k(int i) {
     };
     return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
 }
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
 #endif
 
 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -670,3 +844,151 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 #endif
 }
+
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef z__ARM_NEON
+
+    GGML_ASSERT(false);
+
+#elif defined __AVX2__
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    const __m128i mzero = _mm_setzero_si128();
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t utmp[4];
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d    =  y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s  = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = _mm256_set_m128i(sc128, sc128);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4l = _mm256_and_si256(q4bits, m4);
+            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+            p16l = _mm256_madd_epi16(scale_l, p16l);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16l, p16h));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#else
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
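Both the AVX2 path and the scalar fallback rely on the same algebraic split: for one sub-block, the sum of y[i]*(d*sc*q[i] - dmin*m) over the quantized weights q and the q8 activations y equals d*sc*(sum of y[i]*q[i]) minus dmin*m*(sum of y[i]), and the second sum is exactly what block_q8_K pre-computes in its bsums field, so the mins can be folded in without touching the 4-bit data. A standalone scalar sketch of that identity (illustration only):

    // Sketch only: the min/bsums trick used by ggml_vec_dot_q4_K_q8_K above.
    #include <stdint.h>

    // One 32-weight sub-block: d, dmin are the fp32 super-block scales, sc and m
    // the 6-bit sub-block scale/min, q[] the 4-bit quants, y8[] the q8 activations.
    static float dot_subblock(float d, float dmin, uint8_t sc, uint8_t m,
                              const uint8_t * q, const int8_t * y8, int n) {
        int dot = 0, ysum = 0;
        for (int i = 0; i < n; ++i) {
            dot  += y8[i] * q[i];   // integer dot product, scaled once at the end
            ysum += y8[i];          // this is what block_q8_K.bsums caches
        }
        return d * sc * dot - dmin * m * ysum;
    }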
diff --git a/llama.cpp b/llama.cpp
index c0c072a9b..42aac3022 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -508,6 +508,7 @@ struct llama_file_loader {
             case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
             case GGML_TYPE_Q3_K:
+            case GGML_TYPE_Q4_K:
                 break;
             default: {
                 throw format("unrecognized tensor type %u\n", shard.type);
@@ -584,6 +585,7 @@ struct llama_file_saver {
             case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
             case GGML_TYPE_Q3_K:
+            case GGML_TYPE_Q4_K:
                 break;
             default: LLAMA_ASSERT(false);
         }
@@ -901,6 +903,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0";
         case LLAMA_FTYPE_MOSTLY_Q3_K: return "mostly Q3_K";
+        case LLAMA_FTYPE_MOSTLY_Q4_K: return "mostly Q4_K";
         default:                      return "unknown, may not work";
     }
 }
@@ -2067,6 +2070,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
         case LLAMA_FTYPE_MOSTLY_Q3_K: quantized_type = GGML_TYPE_Q3_K; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_K: quantized_type = GGML_TYPE_Q4_K; break;
         default: throw format("invalid output file type %d\n", ftype);
     };
 
diff --git a/llama.h b/llama.h
index de6671171..a645e3ad8 100644
--- a/llama.h
+++ b/llama.h
@@ -95,6 +95,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q3_K = 10,// except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_K = 11,// except 1d tensors
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
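With the plumbing above in place, the quantize example program accepts the new "q4_K" type string, and the same conversion can be driven from the public API. A minimal sketch (not part of the patch; it assumes the llama_model_quantize(fname_inp, fname_out, ftype, nthread) signature from llama.h of this vintage, and the file paths are placeholders):

    // Sketch only: produce a q4_K model file through the public API.
    #include <stdio.h>
    #include "llama.h"

    int main(void) {
        const int ret = llama_model_quantize(
            "models/7B/ggml-model-f16.bin",    // placeholder input path
            "models/7B/ggml-model-q4_K.bin",   // placeholder output path
            LLAMA_FTYPE_MOSTLY_Q4_K,
            /*nthread=*/ 4);
        if (ret != 0) {
            fprintf(stderr, "quantization failed\n");
            return 1;
        }
        return 0;
    }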