diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu
index 4dd41a671..0e9ef4966 100644
--- a/ggml-cuda/argsort.cu
+++ b/ggml-cuda/argsort.cu
@@ -8,7 +8,7 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }
 
 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
@@ -18,7 +18,7 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_p
     }
 
     const float * x_row = x + row * ncols;
-    int * dst_row = dst_pad + row * ncols_pad;
+    extern __shared__ int dst_row[];
 
     // initialize indices
     dst_row[col] = col;
@@ -69,18 +69,16 @@ static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float *
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
-    ggml_cuda_pool_alloc<int> dst_padded_alloc;
-    int * dst_padded = dst;
-    if (ncols_pad > ncols) {
-        dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad);
-    }
-
     const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ASSERT(false);
    }
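
For readers who want the technique in isolation: below is a minimal, self-contained sketch of what the patch switches to, i.e. a per-row bitonic argsort that pads the column count to the next power of 2 and keeps its index scratchpad in dynamic shared memory rather than in a pool-allocated global buffer. This is an illustration in plain CUDA, not the ggml code itself; the identifiers (k_argsort_shared, next_power_of_2, the main() harness) are local stand-ins, and the padding-aware compare mirrors the ggml kernel's behavior but is reconstructed here from the general bitonic-sort pattern.

// Standalone sketch: per-row bitonic argsort with the index array in
// dynamic shared memory. All names are local stand-ins, not ggml APIs.
#include <cstdio>
#include <cuda_runtime.h>

static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

template <bool ascending>
__global__ void k_argsort_shared(const float * x, int * dst, const int ncols, const int ncols_pad) {
    extern __shared__ int idx[];   // ncols_pad ints, size passed at launch

    const int col = threadIdx.x;   // blockDim.x == ncols_pad, one thread per slot
    const int row = blockIdx.y;

    const float * x_row = x + row * ncols;
    idx[col] = col;                // padding threads hold out-of-range indices
    __syncthreads();

    // textbook bitonic sorting network over the padded width; out-of-range
    // indices compare as "after everything" so the padding sinks to the tail
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;
            if (ixj > col) {
                const int a = idx[col];
                const int b = idx[ixj];
                bool swap;
                if ((col & k) == 0) {  // forward-sorted subsequence
                    swap = a >= ncols || (b < ncols &&
                        (ascending ? x_row[a] > x_row[b] : x_row[a] < x_row[b]));
                } else {               // backward-sorted subsequence
                    swap = b >= ncols || (a < ncols &&
                        (ascending ? x_row[a] < x_row[b] : x_row[a] > x_row[b]));
                }
                if (swap) {
                    idx[col] = b;
                    idx[ixj] = a;
                }
            }
            __syncthreads();
        }
    }

    // only the real (non-padding) columns are written back
    if (col < ncols) {
        dst[row * ncols + col] = idx[col];
    }
}

int main() {
    const int nrows = 2;
    const int ncols = 5;                          // deliberately not a power of 2
    const float h_x[nrows * ncols] = {3.f, 1.f, 4.f, 1.5f, 9.f,
                                      2.f, 7.f, 1.f, 8.f, 2.5f};

    const int ncols_pad = next_power_of_2(ncols); // 8
    const size_t shared_mem = ncols_pad * sizeof(int);

    // same check the patch adds via GGML_ASSERT(shared_mem <= ... .smpb):
    // a row only fits if its padded index array fits in shared memory
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    if (shared_mem > prop.sharedMemPerBlock) {
        fprintf(stderr, "ncols_pad too large for shared memory\n");
        return 1;
    }

    float * d_x;
    int * d_dst;
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_dst, nrows * ncols * sizeof(int));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

    const dim3 block_dims(ncols_pad, 1, 1);
    const dim3 block_nums(1, nrows, 1);
    k_argsort_shared<true><<<block_nums, block_dims, shared_mem>>>(d_x, d_dst, ncols, ncols_pad);

    int h_dst[nrows * ncols];
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    for (int r = 0; r < nrows; ++r) {
        for (int c = 0; c < ncols; ++c) {
            printf("%d ", h_dst[r * ncols + c]); // row 0 prints: 1 3 0 2 4
        }
        printf("\n");
    }

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

The new GGML_ASSERT encodes the trade-off of this approach: the per-row index array must now fit in shared memory per block (smpb in ggml's device info), so ncols_pad is bounded by that limit, whereas the old pool-allocated buffer had no such cap. In exchange, the sort's repeated compare-and-swap traffic on the index array stays in on-chip memory, and the extra global allocation plus the padded write path disappear.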