More readable dequantize_mul_mat_vec logic
This commit is contained in:
parent
9da44fdcb3
commit
5a0ecf768d
1 changed files with 5 additions and 4 deletions
|
@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
const int nb2 = dst->nb[2];
|
||||
const int nb3 = dst->nb[3];
|
||||
const ggml_type type = src0->type;
|
||||
const bool mul_mat_vec = ne11 == 1;
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
|
||||
size_t x_size, y_size, d_size, q_size;
|
||||
float * d_X;
|
||||
if (ne11 > 1) {
|
||||
if (!mul_mat_vec) {
|
||||
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||
}
|
||||
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
||||
|
@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
if (ne11 == 1) {
|
||||
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
|
||||
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||
|
||||
// copy src1 to device
|
||||
|
@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
} else { // general dequantization kernel + cuBLAS matrix matrix multiplication
|
||||
float * c_X = d_X + i * x_ne;
|
||||
|
||||
// convert src0 to fp32 on device
|
||||
|
@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
}
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
if (ne11 > 1) {
|
||||
if (!mul_mat_vec) {
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
}
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue