From 8e3057b24b4c84355c9ecbbcef3e14d18d24ae29 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Mon, 12 Jun 2023 20:23:16 +0200 Subject: [PATCH] Removed obsolete code, fixed multi GPU --- ggml-cuda.cu | 212 +++++++++------------------------------------------ 1 file changed, 37 insertions(+), 175 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 921b5a03b..de0fd447d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -743,7 +743,7 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y } } -static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y) { +static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { const half * x = (half *) vx; // const int col_x = blockDim.x*blockIdx.x + threadIdx.x; // const int row_x = blockDim.y*blockIdx.y + threadIdx.y; @@ -752,7 +752,6 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl const int channel = blockDim.z*blockIdx.z + threadIdx.z; const int nrows_y = ncols_x; - const int ncols_dst = ncols_y; const int nrows_dst = nrows_x; const int row_dst = row_x; @@ -775,11 +774,11 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl // y is not transposed but permuted const int iy = channel*nrows_y + row_y; - // dst is not transposed and not permuted - tmp += xi * y[iy]; } - const int idst = channel*ncols_dst*nrows_dst + row_dst; + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; // sum up partial sums and write back result __syncthreads(); @@ -1143,10 +1142,10 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) { +static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); } static void ggml_mul_mat_vec_nc_f16_f32_cuda( @@ -2024,14 +2023,6 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - // if (strcmp(dst->name, "KQ") == 0) { - // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n", - // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); - // return true; - // } - // TODO: find the optimal values for these if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && @@ -2043,169 +2034,60 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } -void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ - GGML_ASSERT(!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1)); +void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - // GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU); - const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - GGML_ASSERT(ne00 % 8 == 0); - GGML_ASSERT(ne03 == 1); - const size_t src0_size = ggml_nbytes(src0); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - GGML_ASSERT(ne10 % 4 == 0); - GGML_ASSERT(ne13 == 1); - GGML_ASSERT(ne12 == ne02); - const size_t src1_size = ggml_nbytes(src1); - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - GGML_ASSERT(ne2 == ne02); - const size_t dst_size = ggml_nbytes(dst); - - const int64_t nb1 = dst->nb[1]; - const int64_t nb2 = dst->nb[2]; - const int64_t nb3 = dst->nb[3]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; - void * src0_ddq; - float * src1_ddf; - float * dst_ddf; - if (src0->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size)); - CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main)); - } else { - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - src0_ddq = src0_extra->data_device[g_main_device]; - // CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0))); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size)); - CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main)); - } else { - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - src1_ddf = (float *) src1_extra->data_device[g_main_device]; - // CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1))); - } - if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size)); - } else { - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - dst_ddf = (float *) dst_extra->data_device[g_main_device]; - } + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; - for (int64_t i11 = 0; i11 < ne11; ++i11) { - float * src1_ddf_i = src1_ddf + i11 * ne10*ne12; - float * dst_ddf_i = dst_ddf + i11 * ne0; + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf_i, dst_ddf_i, ne00, ne01, ne02, ne11, cudaStream_main); - } + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); CUDA_CHECK(cudaDeviceSynchronize()); - - if (dst->backend == GGML_BACKEND_CPU) { - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i = i3*ne2*ne1 + i2*ne1 + i1; - float * dst_ddf_i = dst_ddf + i*ne0; - float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main)); - } - } - } - } - - CUDA_CHECK(cudaDeviceSynchronize()); - if (src0->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(src0_ddq)); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(src1_ddf)); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(dst_ddf)); - } } -void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - // GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU); - const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - GGML_ASSERT(ne03 == 1); - const size_t src0_size = ggml_nbytes(src0); const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - GGML_ASSERT(ne13 == 1); - GGML_ASSERT(ne12 == ne02); - const size_t src1_size = ggml_nbytes(src1); - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - GGML_ASSERT(ne2 == ne02); - const size_t dst_size = ggml_nbytes(dst); - - const int64_t nb1 = dst->nb[1]; - const int64_t nb2 = dst->nb[2]; - const int64_t nb3 = dst->nb[3]; - + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; - void * src0_ddq; - float * src1_ddf; - float * dst_ddf; - if (src0->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size)); - CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main)); - } else { - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - src0_ddq = src0_extra->data_device[g_main_device]; - // CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0))); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size)); - CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main)); - } else { - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - src1_ddf = (float *) src1_extra->data_device[g_main_device]; - // CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1))); - } - if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size)); - } else { - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - dst_ddf = (float *) dst_extra->data_device[g_main_device]; - } + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; const int row_stride_x = nb01 / sizeof(half); const int channel_stride_x = nb02 / sizeof(half); @@ -2213,37 +2095,16 @@ void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tenso ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); CUDA_CHECK(cudaDeviceSynchronize()); - - if (dst->backend == GGML_BACKEND_CPU) { - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i = i3*ne2*ne1 + i2*ne1 + i1; - float * dst_ddf_i = dst_ddf + i*ne0; - float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main)); - } - } - } - } - - CUDA_CHECK(cudaDeviceSynchronize()); - if (src0->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(src0_ddq)); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(src1_ddf)); - } - if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaFree(dst_ddf)); - } } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst); - } else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { - ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst); + bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + + if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_p021(src0, src1, dst); + } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_nc(src0, src1, dst); }else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { @@ -2288,6 +2149,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens const int64_t nb11 = src1->nb[1]; const int64_t nb12 = src1->nb[2]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;