From f421b32d5a603bc9d5b2309eabb779d12980b809 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 2 Apr 2024 20:09:25 +0200 Subject: [PATCH] cuda/argsort : use shared memory instead of pool memory --- ggml-cuda/argsort.cu | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu index 4dd41a671..0e9ef4966 100644 --- a/ggml-cuda/argsort.cu +++ b/ggml-cuda/argsort.cu @@ -8,7 +8,7 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -18,7 +18,7 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_p } const float * x_row = x + row * ncols; - int * dst_row = dst_pad + row * ncols_pad; + extern __shared__ int dst_row[]; // initialize indices dst_row[col] = col; @@ -69,18 +69,16 @@ static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float * // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); - ggml_cuda_pool_alloc dst_padded_alloc; - int * dst_padded = dst; - if (ncols_pad > ncols) { - dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad); - } - const dim3 block_dims(ncols_pad, 1, 1); const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, dst_padded, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, dst_padded, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); } else { GGML_ASSERT(false); }