From e51ce72e03fe487f5b1d614287a6724559882afe Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 1 Jun 2023 14:01:25 +0300 Subject: [PATCH] Fixed bug in Q2_K CUDA dot product kernel Stranegly enough, for the few prompts I tried with the 7B model the responses looked perfectly reasonable. Only realized something is not quite right when I tried the larger models and started getting nonse back. In any case, Q2_K single token evaluation time on an RTX 4080 in a Ryzen7950X box iusing CUDA and model fully loaded on the GPU are ~15.5 ms for 7B, ~25.4 ms for 13B, and ~55.8 ms for 30B. The max number of layers that fit in VRAM for The 65B is 32. With that, we get ~330 ms per token, which is not that much faster than just running on the CPU (~470 ms per token). --- ggml-cuda.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 71e99ad85..1a13d2b8e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -278,7 +278,7 @@ static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) - + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[1] >> 4)) + + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); @@ -753,8 +753,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 2, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<>>(vx, y, dst, ncols); + const int ny = 2; + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {