cuda/argsort : use shared memory instead of pool memory
This commit is contained in:
parent
9530398013
commit
f421b32d5a
1 changed file with 8 additions and 10 deletions
|
@@ -8,7 +8,7 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<ggml_sort_order order>
|
template<ggml_sort_order order>
|
||||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) {
|
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
int col = threadIdx.x;
|
int col = threadIdx.x;
|
||||||
int row = blockIdx.y;
|
int row = blockIdx.y;
|
||||||
|
@@ -18,7 +18,7 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_p
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * x_row = x + row * ncols;
|
const float * x_row = x + row * ncols;
|
||||||
int * dst_row = dst_pad + row * ncols_pad;
|
extern __shared__ int dst_row[];
|
||||||
|
|
||||||
// initialize indices
|
// initialize indices
|
||||||
dst_row[col] = col;
|
dst_row[col] = col;
|
||||||
|
@@ -69,18 +69,16 @@ static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float *
|
||||||
// bitonic sort requires ncols to be power of 2
|
// bitonic sort requires ncols to be power of 2
|
||||||
const int ncols_pad = next_power_of_2(ncols);
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
|
||||||
ggml_cuda_pool_alloc<int> dst_padded_alloc;
|
|
||||||
int * dst_padded = dst;
|
|
||||||
if (ncols_pad > ncols) {
|
|
||||||
dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad);
|
|
||||||
}
|
|
||||||
|
|
||||||
const dim3 block_dims(ncols_pad, 1, 1);
|
const dim3 block_dims(ncols_pad, 1, 1);
|
||||||
const dim3 block_nums(1, nrows, 1);
|
const dim3 block_nums(1, nrows, 1);
|
||||||
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
|
||||||
|
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue