From 5a0ecf768d1d69e4f40c3e922f72e88112580d2f Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 13 May 2023 07:14:27 +0200 Subject: [PATCH] More readable dequantize_mul_mat_vec logic --- ggml-cuda.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 66a2b0f93..161b54682 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; const ggml_type type = src0->type; + const bool mul_mat_vec = ne11 == 1; const float alpha = 1.0f; const float beta = 0.0f; @@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor size_t x_size, y_size, d_size, q_size; float * d_X; - if (ne11 > 1) { + if (!mul_mat_vec) { d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); } float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); @@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } else { GGML_ASSERT(false); } - if (ne11 == 1) { + if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); // copy src1 to device @@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); CUDA_CHECK(cudaGetLastError()); - } else { + } else { // general dequantization kernel + cuBLAS matrix matrix multiplication float * c_X = d_X + i * x_ne; // convert src0 to fp32 on device @@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } CUDA_CHECK(cudaDeviceSynchronize()); - if (ne11 > 1) { + if (!mul_mat_vec) { ggml_cuda_pool_free(d_X, x_size); } ggml_cuda_pool_free(d_Y, y_size);