More readable dequantize_mul_mat_vec logic

JohannesGaessler 2023-05-13 07:14:27 +02:00
parent 9da44fdcb3
commit 5a0ecf768d


@@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
     const ggml_type type = src0->type;
+    const bool mul_mat_vec = ne11 == 1;
 
     const float alpha = 1.0f;
     const float beta = 0.0f;
@@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
 
     size_t x_size, y_size, d_size, q_size;
     float * d_X;
-    if (ne11 > 1) {
+    if (!mul_mat_vec) {
         d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
     }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             } else {
                 GGML_ASSERT(false);
             }
 
-            if (ne11 == 1) {
+            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
                 // copy src1 to device
@@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
                 dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 
-            } else {
+            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
                 float * c_X = d_X + i * x_ne;
 
                 // convert src0 to fp32 on device
@@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    if (ne11 > 1) {
+    if (!mul_mat_vec) {
         ggml_cuda_pool_free(d_X, x_size);
     }
     ggml_cuda_pool_free(d_Y, y_size);
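
For context, a minimal self-contained sketch of the dispatch pattern this commit makes explicit. This is illustrative C, not the ggml-cuda.cu code: mat_vec and mat_mat are hypothetical stand-ins for the specialized dequantize_mul_mat_vec kernel and the general dequantization + cuBLAS path.

#include <stdbool.h>
#include <stdio.h>

// Stand-in for the specialized matrix-vector kernel: d = x * y,
// where x is ne01 x ne00 (row-major) and y has ne00 elements.
static void mat_vec(const float * x, const float * y, float * d, int ne00, int ne01) {
    for (int r = 0; r < ne01; r++) {
        float sum = 0.0f;
        for (int c = 0; c < ne00; c++) {
            sum += x[r * ne00 + c] * y[c];
        }
        d[r] = sum;
    }
}

// Stand-in for the general matrix-matrix path (one column of y at a time).
static void mat_mat(const float * x, const float * y, float * d, int ne00, int ne01, int ne11) {
    for (int j = 0; j < ne11; j++) {
        mat_vec(x, y + j * ne00, d + j * ne01, ne00, ne01);
    }
}

static void mul_mat(const float * x, const float * y, float * d, int ne00, int ne01, int ne11) {
    const bool mul_mat_vec = ne11 == 1; // name the condition once, as the commit does

    if (mul_mat_vec) { // specialized matrix-vector path
        mat_vec(x, y, d, ne00, ne01);
    } else {           // general matrix-matrix path
        mat_mat(x, y, d, ne00, ne01, ne11);
    }
}

int main(void) {
    const float x[4] = {1, 2, 3, 4}; // 2x2 matrix, row-major
    const float y[2] = {1, 1};       // ne11 == 1 -> takes the matrix-vector path
    float d[2];
    mul_mat(x, y, d, 2, 2, 1);
    printf("%g %g\n", d[0], d[1]);   // prints: 3 7
    return 0;
}

Hoisting the test into the named boolean mul_mat_vec keeps the three branch sites consistent: before this commit the same condition appeared both as ne11 == 1 and as ne11 > 1, which is easy to let drift apart when editing one site.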