From cf221afb555a945be4d1e4153e38808a9d21a4cb Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 29 May 2023 16:02:54 +0300 Subject: [PATCH] Adding Q6_K - scalar, AVX2, CUDA Performance is ~40% lower compared to Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 6-bit model is ~44% larger than the 4-bit. On the GPU, single token prediction is ~6% lower than Q4_0, batch mode (perplexity) is even closer (but still slower). --- examples/quantize/quantize.cpp | 1 + ggml-cuda.cu | 75 ++++++++++++ ggml.c | 35 +++++- ggml.h | 5 +- k_quants.c | 203 ++++++++++++++++++++++++++++++++- llama.cpp | 4 + llama.h | 1 + 7 files changed, 317 insertions(+), 7 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1e106ee10..b0b4b0b71 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -14,6 +14,7 @@ static const std::map LLAMA_FTYPE_MAP = { {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, {"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K}, {"q4_K", LLAMA_FTYPE_MOSTLY_Q4_K}, + {"q6_K", LLAMA_FTYPE_MOSTLY_Q6_K}, }; bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 81dffe474..51c976866 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -105,6 +105,14 @@ typedef struct { } block_q4_K; static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); + #define WARP_SIZE 32 #define CUDA_MUL_BLOCK_SIZE 256 @@ -347,6 +355,58 @@ static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs } +static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + float * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + +static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q6_K * x = (const block_q6_K *) vx; + + const int ip = iqs / 128; // 0 or 1 + const int il = (iqs - 128*ip)/8; // 0...15 + const int is = 8*ip; + + const float * y = yy + 128*ip + il; + + const float d = x[ib].d; + + const uint8_t * ql = x[ib].ql + 64*ip + il; + const uint8_t * qh = x[ib].qh + 32*ip + il; + const int8_t * sc = x[ib].scales + is; + + result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) + + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) + + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) + + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) + + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) + + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) + + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) + + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + +} + static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const half * x = (const half *) vx; @@ -496,6 +556,11 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu dequantize_block_q4_K<<>>(vx, y); } +static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q6_K<<>>(vx, y); +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); @@ -548,6 +613,12 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<>>(vx, y, dst, ncols); } +static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<>>(vx, y, dst, ncols); +} + static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<32, 1, convert_f16><<>>(vx, y, k); @@ -577,6 +648,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; case GGML_TYPE_F16: return convert_fp16_to_fp32_cuda; default: @@ -600,6 +673,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t return dequantize_mul_mat_vec_q3_K_cuda; case GGML_TYPE_Q4_K: return dequantize_mul_mat_vec_q4_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_mul_mat_vec_q6_K_cuda; case GGML_TYPE_F16: return convert_mul_mat_vec_f16_cuda; default: diff --git a/ggml.c b/ggml.c index 3cf5ec22d..dce6b435c 100644 --- a/ggml.c +++ b/ggml.c @@ -1582,6 +1582,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, }, + [GGML_TYPE_Q6_K] = { + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_K, + .quantize_row_q = quantize_row_q6_K, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, }; // For internal test use @@ -3463,12 +3471,13 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_Q3_K] = QK_K, [GGML_TYPE_Q4_K] = QK_K, + [GGML_TYPE_Q6_K] = QK_K, [GGML_TYPE_Q8_K] = QK_K, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 16, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 17, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3480,12 +3489,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_Q4_K] = sizeof(block_q4_K), + [GGML_TYPE_Q6_K] = sizeof(block_q6_K), [GGML_TYPE_Q8_K] = sizeof(block_q8_K), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 17, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3499,12 +3509,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_Q3_K] = "q3_K", [GGML_TYPE_Q4_K] = "q4_K", + [GGML_TYPE_Q6_K] = "q6_K", [GGML_TYPE_Q8_K] = "q8_K", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 16, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 17, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3516,12 +3527,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, [GGML_TYPE_Q4_K] = true, + [GGML_TYPE_Q6_K] = true, [GGML_TYPE_Q8_K] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 16, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 17, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", @@ -3830,6 +3842,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7615,6 +7628,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q8_0: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7920,6 +7934,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8044,6 +8059,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: default: { GGML_ASSERT(false); @@ -10139,6 +10155,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -10324,6 +10341,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: default: { GGML_ASSERT(false); @@ -10491,6 +10509,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11039,6 +11058,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -11113,6 +11133,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_1: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -16129,6 +16150,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_K * block = (block_q4_K*)dst + start / QK_K; result = ggml_quantize_q4_K(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q6_K * block = (block_q6_K*)dst + start / QK_K; + result = ggml_quantize_q6_K(src + start, block, n, n, hist); + } break; default: assert(false); } diff --git a/ggml.h b/ggml.h index c38745bc1..a3d440a96 100644 --- a/ggml.h +++ b/ggml.h @@ -245,8 +245,8 @@ extern "C" { GGML_TYPE_Q3_K = 10, GGML_TYPE_Q4_K = 11, //GGML_TYPE_Q5_K = 12, - //GGML_TYPE_Q6_K = 13, - GGML_TYPE_Q8_K = 12, + GGML_TYPE_Q6_K = 12, + GGML_TYPE_Q8_K = 13, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -272,6 +272,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors GGML_FTYPE_MOSTLY_Q3_K = 10, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 12, // except 1d tensors }; // available tensor operations: diff --git a/k_quants.c b/k_quants.c index e9f3160c1..2000f42c7 100644 --- a/k_quants.c +++ b/k_quants.c @@ -565,6 +565,58 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict } } +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + + } +} + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_reference(x, y, k); +} + +size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + (void)hist; // TODO + + for (int j = 0; j < nb; j += k) { + block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; + quantize_row_q6_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q6_K)); +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -635,7 +687,7 @@ static inline float hsum_float_8(const __m256 x) { return _mm_cvtss_f32(res); } -// shuffle to pick the required scales in dot products +// shuffles to pick the required scales in dot products static inline __m256i get_scale_shuffle_q3k(int i) { static const uint8_t k_shuffle[128] = { 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, @@ -658,6 +710,19 @@ static inline __m256i get_scale_shuffle_k4(int i) { }; return _mm256_loadu_si256((const __m256i*)k_shuffle + i); } +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} #endif void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -992,3 +1057,139 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; #endif } + +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef z__ARM_NEON + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + + diff --git a/llama.cpp b/llama.cpp index 42aac3022..b9c1ccd54 100644 --- a/llama.cpp +++ b/llama.cpp @@ -509,6 +509,7 @@ struct llama_file_loader { case GGML_TYPE_Q8_0: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: break; default: { throw format("unrecognized tensor type %u\n", shard.type); @@ -586,6 +587,7 @@ struct llama_file_saver { case GGML_TYPE_Q8_0: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: break; default: LLAMA_ASSERT(false); } @@ -904,6 +906,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; case LLAMA_FTYPE_MOSTLY_Q3_K: return "mostly Q3_K"; case LLAMA_FTYPE_MOSTLY_Q4_K: return "mostly Q4_K"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; default: return "unknown, may not work"; } } @@ -2071,6 +2074,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_Q3_K: quantized_type = GGML_TYPE_Q3_K; break; case LLAMA_FTYPE_MOSTLY_Q4_K: quantized_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index a645e3ad8..cca5eb66f 100644 --- a/llama.h +++ b/llama.h @@ -96,6 +96,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q3_K = 10,// except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_K = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 12,// except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params();