From c93cce3a450f23c1678135929caf0f177052f132 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 28 May 2023 21:38:00 +0300 Subject: [PATCH] Q3_K now working on CUDA and AVX2/scalar CUDA is not ideal - ~50% slower than Q4_0 for single token prediction, about the same in batch mode (perplexity). CPU single token is ~55 ms (on Ryzen 7950X). --- Makefile | 4 +- examples/quantize-stats/quantize-stats.cpp | 5 +- examples/quantize/quantize.cpp | 1 + ggml-cuda.cu | 142 ++++++++++++++ ggml.c | 16 +- k_quants.c | 214 +++++++++++++++++++++ llama.cpp | 4 + llama.h | 1 + 8 files changed, 381 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 0be2b2334..4db73357e 100644 --- a/Makefile +++ b/Makefile @@ -229,7 +229,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o k_quants.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' @@ -269,7 +269,7 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) ./$@ -vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) +vdot: pocs/vdot/vdot.cpp ggml.o k_quants.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) .PHONY: tests clean diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 085fdde3c..6e4f7e1e0 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -282,8 +282,9 @@ int main(int argc, char ** argv) { break; } int j; - for (j = 0; j < GGML_TYPE_COUNT && strcmp(argv[i], ggml_type_name((ggml_type) j)) != 0; j++) { - // find match + for (j = 0; j < GGML_TYPE_COUNT; ++j) { + const auto * name = ggml_type_name((ggml_type) j); + if (name && strcmp(argv[i], name) == 0) break; } if (j < GGML_TYPE_COUNT) { params.include_types.push_back((ggml_type) j); diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 769dd36a4..1f17a2acb 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -12,6 +12,7 @@ static const std::map LLAMA_FTYPE_MAP = { {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, + {"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K}, }; bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 98170a3ae..84c5e8b7d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); +typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); // QK = number of values after dequantization // QR = QK / number of values before dequantization @@ -83,6 +85,18 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); +//================================= k-quants + +#define QK_K 256 + +typedef struct { + uint8_t hmask[QK_K/8]; + 
uint8_t qs[QK_K/4]; // nibbles / quants + uint8_t scales[3*QK_K/64]; + half d; +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); + #define WARP_SIZE 32 #define CUDA_MUL_BLOCK_SIZE 256 @@ -184,6 +198,80 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int v1 = vi1*d; } +//================================== k-quants + +static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + int i = blockIdx.x; + int tid = threadIdx.x; + int n = tid / 4; + int j = tid - 4*n; + + const block_q3_K * x = (const block_q3_K *) vx; + + float * y = yy + i*QK_K + 128*n + 32*j; + + float d_all = x[i].d; + + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j; + float dl; + int shift = 2*j; + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + + dl = d_all * (scales[is++] - 32); + for (int l = 16; l < 32; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + +} + +static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * y, float & result) { + + const block_q3_K * x = (const block_q3_K *) vx; + + int n = iqs / 128; + int iqsn = iqs - 128*n; + int j = iqsn / 32; + int l = iqsn - 32*j; + int shift = 2*j; + int is = 8*n + 2*j + l/16; + uint8_t m = 1 << (4*n + j); + + const float d = x[ib].d; + const uint8_t * q = x[ib].qs + 32*n; + const uint8_t * hm = x[ib].hmask; + + int8_t us = is < 4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is+0] >> 4) & 3) << 4) : + (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is-4] >> 6) & 3) << 4); + float scale = d * (us - 32); + + float sum = 0; + for (int k = 0; k < 8; ++k) { + int8_t ql = (q[l + k] >> shift) & 3; + sum += y[iqs + k] * (ql - ((hm[l + k] & m) ? 
0 : 4)); + } + result = sum * scale; +} + static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const half * x = (const half *) vx; @@ -258,6 +346,43 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } +template +static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + const int iter_stride = QK_K; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + + float tmp = 0; // partial sum for thread in warp + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK_K; // x block index + const int iqs = col%QK_K; // x quant index + const int iybs = col - col%QK_K; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 8) { + float v; + dot_kernel(vx, ib, iqs + j, y + iybs, v); + tmp += v; + } + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<>>(x, y, dst, kx, ky); @@ -288,6 +413,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu dequantize_block<<>>(vx, y, k); } +static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q3_K<<>>(vx, y); +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); @@ -328,6 +458,12 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols); } +static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat_vec_k<<>>(vx, y, dst, ncols); +} + static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<32, 1, convert_f16><<>>(vx, y, k); @@ -353,6 +489,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; case GGML_TYPE_F16: return convert_fp16_to_fp32_cuda; default: @@ -372,6 +510,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t return dequantize_mul_mat_vec_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_mul_mat_vec_q8_0_cuda; + case GGML_TYPE_Q3_K: + return dequantize_mul_mat_vec_q3_K_cuda; case GGML_TYPE_F16: return convert_mul_mat_vec_f16_cuda; default: @@ -790,12 +930,14 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor 
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); // compute + //printf("Calling dmmv\n"); dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); CUDA_CHECK(cudaGetLastError()); } else { // general dequantization kernel + cuBLAS matrix matrix multiplication float * c_X = d_X + i * x_ne; +//typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); // convert src0 to fp32 on device to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml.c b/ggml.c index 1fd1e4a5e..0e849a621 100644 --- a/ggml.c +++ b/ggml.c @@ -1570,8 +1570,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K, .quantize_row_q = quantize_row_q3_K, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference, - .quantize_row_q_dot = NULL, //quantize_row_q8_K, - .vec_dot_q = NULL, //ggml_vec_dot_q3_K_q8_K, + .quantize_row_q_dot = quantize_row_q8_K, + .vec_dot_q = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, }, }; @@ -7602,6 +7602,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7905,6 +7906,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q3_K: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8027,6 +8029,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q3_K: default: { GGML_ASSERT(false); @@ -10120,6 +10123,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q3_K: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -10303,6 +10307,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q3_K: default: { GGML_ASSERT(false); @@ -10468,6 +10473,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q3_K: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -16092,6 +16098,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; result = ggml_quantize_q8_0(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q3_K * block = (block_q3_K*)dst + start / QK_K; + result = ggml_quantize_q3_K(src + start, block, n, n, hist); + } break; default: assert(false); } diff --git a/k_quants.c b/k_quants.c index a929bb6ad..86cc90616 100644 --- a/k_quants.c +++ b/k_quants.c @@ -5,6 +5,8 @@ #include #include +#include + #undef MIN #undef MAX #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -456,3 +458,215 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { quantize_row_q8_K_reference(x, y, k); } +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// shuffle to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef z__ARM_NEON + // TODO +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = 
_mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 
0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} diff --git a/llama.cpp b/llama.cpp index 47b4c8dd7..c0c072a9b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -507,6 +507,7 @@ struct llama_file_loader { case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: break; default: { throw format("unrecognized tensor type %u\n", shard.type); @@ -582,6 +583,7 @@ struct llama_file_saver { case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: break; default: LLAMA_ASSERT(false); } @@ -898,6 +900,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q3_K: return "mostly Q3_K"; default: return "unknown, may not work"; } } @@ -2063,6 +2066,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_Q3_K: quantized_type = GGML_TYPE_Q3_K; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index c6b0a2889..de6671171 100644 --- a/llama.h +++ b/llama.h @@ -94,6 +94,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K = 10,// except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params();
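Note on the Q3_K layout used above: each super-block covers QK_K = 256 weights and occupies 110 bytes (32 bytes of high-bit masks, 64 bytes of low 2-bit quants, 12 bytes of packed 6-bit scales, and one fp16 super-block scale), i.e. 3.4375 bits per weight. Each group of 16 weights gets its own 6-bit scale, the stored 2-bit value supplies the low bits of a quant, and the corresponding hmask bit supplies the third bit; when that bit is not set, 4 is subtracted, so the effective quant range is -4..3 before scaling.

The following is a minimal, self-contained scalar sketch of the dequantization path, written to mirror the CUDA kernel and the scalar dot product in this patch. The standalone struct and the function name dequantize_block_q3_K_ref are for illustration only; ggml_fp16_to_fp32 is assumed to be ggml's existing fp16 helper.

    #include <stdint.h>
    #include <string.h>

    #define QK_K 256

    typedef uint16_t ggml_fp16_t;                 // fp16 storage type, as in ggml
    extern float ggml_fp16_to_fp32(ggml_fp16_t);  // assumed: ggml's fp16 -> fp32 helper

    typedef struct {
        uint8_t     hmask[QK_K/8];      // third bit of each quant, 1 bit per value
        uint8_t     qs[QK_K/4];         // low 2 bits, 4 values per byte
        uint8_t     scales[3*QK_K/64];  // 16 6-bit scales packed into 12 bytes
        ggml_fp16_t d;                  // super-block scale
    } block_q3_K;

    // Dequantize one super-block into y[0..255], following the same value order
    // and sign convention as dequantize_block_q3_K / ggml_vec_dot_q3_K_q8_K above.
    static void dequantize_block_q3_K_ref(const block_q3_K * x, float * y) {
        const uint32_t kmask1 = 0x03030303, kmask2 = 0x0f0f0f0f;

        // unpack the 16 signed 6-bit scales (stored with a +32 bias)
        uint32_t aux[4];
        memcpy(aux, x->scales, 12);
        const uint32_t tmp = aux[2];
        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        aux[0] = ( aux[0]       & kmask2) | (((tmp >> 0) & kmask1) << 4);
        aux[1] = ( aux[1]       & kmask2) | (((tmp >> 2) & kmask1) << 4);
        const int8_t * scales = (const int8_t *) aux;

        const float d_all = ggml_fp16_to_fp32(x->d);

        int is = 0;       // scale index, advances by 2 per 32-value group
        uint8_t m = 1;    // selects the current hmask bit plane
        for (int j = 0; j < QK_K; j += 128) {                 // two 128-value halves
            for (int shift = 0; shift <= 6; shift += 2) {     // four 2-bit planes per half
                for (int l = 0; l < 32; ++l) {
                    const int   q  = (x->qs[j/4 + l] >> shift) & 3;
                    const int   hi = (x->hmask[l] & m) ? 0 : 4;        // subtract 4 if the high bit is NOT set
                    const float dl = d_all * (scales[is + l/16] - 32); // remove the +32 bias
                    *y++ = dl * (q - hi);
                }
                is += 2;
                m <<= 1;
            }
        }
    }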
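The CUDA dot-product path does not unpack all 16 scales up front; vec_dot_q3_K reconstructs a single 6-bit scale on the fly from the 12 packed bytes. Its four-way conditional encodes the packing: the low 4 bits of scales 0-7 sit in the low nibbles of bytes 0-7, the low 4 bits of scales 8-15 sit in the high nibbles of the same bytes, and the top 2 bits of all 16 scales are packed into bytes 8-11. A small sketch of that lookup as a standalone helper (the name q3_K_get_scale is illustrative only):

    #include <stdint.h>

    // Recover scale `is` (0..15, biased by +32) from the packed 12-byte scales
    // array, matching the four-way conditional in the CUDA vec_dot_q3_K kernel.
    static inline int q3_K_get_scale(const uint8_t * scales, int is) {
        if (is <  4) return (scales[is  ] & 0xF) | (((scales[is+8] >> 0) & 3) << 4);
        if (is <  8) return (scales[is  ] & 0xF) | (((scales[is+4] >> 2) & 3) << 4);
        if (is < 12) return (scales[is-8] >>  4) | (((scales[is  ] >> 4) & 3) << 4);
        return              (scales[is-8] >>  4) | (((scales[is-4] >> 6) & 3) << 4);
    }

The effective multiplier for a 16-value group is then d * (q3_K_get_scale(scales, is) - 32), which agrees with the kmask-based bulk unpack used in the dequantization kernel and in the AVX2/scalar dot products.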