cuda/argsort : use shared memory instead of pool memory

This commit is contained in:
slaren 2024-04-02 20:09:25 +02:00
parent 9530398013
commit f421b32d5a

View file

@ -8,7 +8,7 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
} }
template<ggml_sort_order order> template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) { static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort // bitonic sort
int col = threadIdx.x; int col = threadIdx.x;
int row = blockIdx.y; int row = blockIdx.y;
@ -18,7 +18,7 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_p
} }
const float * x_row = x + row * ncols; const float * x_row = x + row * ncols;
int * dst_row = dst_pad + row * ncols_pad; extern __shared__ int dst_row[];
// initialize indices // initialize indices
dst_row[col] = col; dst_row[col] = col;
@ -69,18 +69,16 @@ static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float *
// bitonic sort requires ncols to be power of 2 // bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols); const int ncols_pad = next_power_of_2(ncols);
ggml_cuda_pool_alloc<int> dst_padded_alloc;
int * dst_padded = dst;
if (ncols_pad > ncols) {
dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad);
}
const dim3 block_dims(ncols_pad, 1, 1); const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1); const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad); k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) { } else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad); k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }