From 1b5d78d3ee237571f9035332cd7c25ad7c4ed368 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 6 Apr 2024 15:19:47 +0200 Subject: [PATCH] minor --- ggml-cuda.cu | 61 ++++++++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 74a50594f..297fdbe13 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1961,22 +1961,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } struct mmid_row_mapping { - int64_t i1; - int64_t i2; + int32_t i1; + int32_t i2; }; -static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous, - int * cur_src1_row, mmid_row_mapping * row_mapping, - const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0, - int64_t ids_ne1, int64_t n_ids, - int64_t ne11, +static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, + int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, + const char * ids_dev, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, size_t nb11, size_t nb12) { - int64_t iid1 = blockIdx.x; - int64_t id = blockIdx.y; - - if (iid1 >= ids_ne1 || id >= n_ids) { - return; - } + int32_t iid1 = blockIdx.x; + int32_t id = blockIdx.y; const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0); @@ -1994,31 +1989,27 @@ static __global__ void k_copy_src1_to_contiguous(const char * src1_original, cha } __syncthreads(); - const char * src1_row_original = src1_original + i11*nb11 + i12*nb12; - char * src1_row_contiguous = src1_contiguous + src1_row*nb11; + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - for (int i = threadIdx.x; i < nb11; i += blockDim.x) { + for (int i = threadIdx.x; i < ne10; i += blockDim.x) { src1_row_contiguous[i] = src1_row_original[i]; } } -static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous, - const mmid_row_mapping * row_mapping, - int64_t n_rows, - int64_t nb1, int64_t nb2) { - int64_t i = blockIdx.x; +static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne0, + size_t nb1, size_t nb2) { + int32_t i = blockIdx.x; - if (i >= n_rows) { - return; - } + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; - const int64_t i1 = row_mapping[i].i1; - const int64_t i2 = row_mapping[i].i2; + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); - const char * dst_row_contiguous = dst_contiguous + i*nb1; - char * dst_row_original = dst_original + i1*nb1 + i2*nb2; - - for (int j = threadIdx.x; j < nb1; j += blockDim.x) { + for (int j = threadIdx.x; j < ne0; j += blockDim.x) { dst_row_original[j] = dst_row_contiguous[j]; } } @@ -2129,14 +2120,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); { - dim3 block_dims(std::min((uint)nb11, 1024u)); + dim3 block_dims(std::min((uint)ne10, 512u)); dim3 grid_dims(ids->ne[1], n_ids); k_copy_src1_to_contiguous<<>>( src1_original, src1_contiguous.get(), dev_cur_src1_row.get(), dev_row_mapping.get(), ids_dev, i02, ids->nb[1], ids->nb[0], - ids->ne[1], n_ids, - ne11, + ne11, ne10, nb11, nb12); CUDA_CHECK(cudaGetLastError()); } @@ -2161,12 +2151,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * #ifndef MMID_MEMCPY { - dim3 block_dims(std::min((uint)nb1, 1024u)); + dim3 block_dims(std::min((uint)ne0, 512u)); dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_contiguous.get(), dev_row_mapping.get(), - num_src1_rows, nb1, nb2); + ne0, + nb1, nb2); CUDA_CHECK(cudaGetLastError()); } #endif