From 4531b029eece7c52e526067db1bf12b8adaa3114 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 2 Apr 2024 01:11:59 +0200 Subject: [PATCH] cuda : support non-pow-2 number of experts --- ggml-cuda/argsort.cu | 60 ++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu index 1333287e4..4dd41a671 100644 --- a/ggml-cuda/argsort.cu +++ b/ggml-cuda/argsort.cu @@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; - if (col >= ncols) return; + if (col >= ncols_pad) { + return; + } const float * x_row = x + row * ncols; - int * dst_row = dst + row * ncols; + int * dst_row = dst_pad + row * ncols_pad; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + __syncthreads(); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { ggml_cuda_swap(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { ggml_cuda_swap(dst_row[col], dst_row[ixj]); } } @@ -41,18 +50,37 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n __syncthreads(); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { - // bitonic sort requires ncols to be power of 2 - GGML_ASSERT((ncols & (ncols - 1)) == 0); +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} - const dim3 block_dims(ncols, 1, 1); +static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + ggml_cuda_pool_alloc dst_padded_alloc; + int * dst_padded = dst; + if (ncols_pad > ncols) { + dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad); + } + + const dim3 block_dims(ncols_pad, 1, 1); const dim3 block_nums(1, nrows, 1); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols); + k_argsort_f32_i32<<>>(x, dst, dst_padded, ncols, ncols_pad); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols); + k_argsort_f32_i32<<>>(x, dst, dst_padded, ncols, ncols_pad); } else { GGML_ASSERT(false); } @@ -73,5 +101,5 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_f32_i32_cuda(ctx, src0_d, (int *)dst_d, ncols, nrows, order, stream); }