More readable dequantize_mul_mat_vec logic

This commit is contained in:
JohannesGaessler 2023-05-13 07:14:27 +02:00
parent 9da44fdcb3
commit 5a0ecf768d

View file

@@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const int nb2 = dst->nb[2]; const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3]; const int nb3 = dst->nb[3];
const ggml_type type = src0->type; const ggml_type type = src0->type;
const bool mul_mat_vec = ne11 == 1;
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
@@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
size_t x_size, y_size, d_size, q_size; size_t x_size, y_size, d_size, q_size;
float * d_X; float * d_X;
if (ne11 > 1) { if (!mul_mat_vec) {
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
} }
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
if (ne11 == 1) { if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
// copy src1 to device // copy src1 to device
@@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} else { } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
float * c_X = d_X + i * x_ne; float * c_X = d_X + i * x_ne;
// convert src0 to fp32 on device // convert src0 to fp32 on device
@@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
} }
CUDA_CHECK(cudaDeviceSynchronize()); CUDA_CHECK(cudaDeviceSynchronize());
if (ne11 > 1) { if (!mul_mat_vec) {
ggml_cuda_pool_free(d_X, x_size); ggml_cuda_pool_free(d_X, x_size);
} }
ggml_cuda_pool_free(d_Y, y_size); ggml_cuda_pool_free(d_Y, y_size);