From 5288bd58960ae25501deca21c8d02b4b289c643e Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sat, 1 Feb 2025 09:37:29 +0530 Subject: [PATCH] Argsort: move to a separate file --- ggml/src/ggml-sycl/argsort.cpp | 120 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/argsort.hpp | 8 ++ ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 129 ------------------------------- 4 files changed, 129 insertions(+), 129 deletions(-) create mode 100644 ggml/src/ggml-sycl/argsort.cpp create mode 100644 ggml/src/ggml-sycl/argsort.hpp diff --git a/ggml/src/ggml-sycl/argsort.cpp b/ggml/src/ggml-sycl/argsort.cpp new file mode 100644 index 000000000..74cb0afd6 --- /dev/null +++ b/ggml/src/ggml-sycl/argsort.cpp @@ -0,0 +1,120 @@ +#include "argsort.hpp" + +template +static inline void ggml_sycl_swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +__dpct_inline__ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, + const sycl::nd_item<3> & item_ct1, uint8_t * dpct_local) { + // bitonic sort + int col = item_ct1.get_local_id(2); + int row = item_ct1.get_group(1); + + if (col >= ncols_pad) { + return; + } + + const float * x_row = x + row * ncols; + auto dst_row = (int *) dpct_local; + + // initialize indices + dst_row[col] = col; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && + (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]]))) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && + (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]]))) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } + } + } + /* + DPCT1118:1: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +static void argsort_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, + queue_ptr stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const sycl::range<3> block_dims(1, 1, ncols_pad); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + if (order == GGML_SORT_ORDER_ASC) { + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh); + + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_argsort_f32_i32( + x, dst, ncols, ncols_pad, item_ct1, + dpct_local_acc_ct1.get_multi_ptr().get()); + }); + }); + } else if (order == GGML_SORT_ORDER_DESC) { + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh); + + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_argsort_f32_i32( + x, dst, ncols, ncols_pad, item_ct1, + dpct_local_acc_ct1.get_multi_ptr().get()); + }); + }); + } else { + GGML_ABORT("fatal error"); + } +} + +inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + dpct::queue_ptr main_stream = ctx.stream(); + const float * src0_dd = static_cast(dst->src[0]->data); + int32_t * dst_dd = static_cast(dst->data); + + argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_argsort(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} \ No newline at end of file diff --git a/ggml/src/ggml-sycl/argsort.hpp b/ggml/src/ggml-sycl/argsort.hpp new file mode 100644 index 000000000..e79d20e8a --- /dev/null +++ b/ggml/src/ggml-sycl/argsort.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_ARGSORT_HPP +#define GGML_SYCL_ARGSORT_HPP + +#include "common.hpp" + +void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_ARGSORT_HPP diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 05bc85ded..ece5449d6 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -31,6 +31,7 @@ #include "element_wise.hpp" #include "binbcast.hpp" #include "argmax.hpp" +#include "argsort.hpp" #include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f9ea4258e..f4d606b4a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1730,70 +1730,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols, } -template -static inline void ggml_sycl_swap(T & a, T & b) { - T tmp = a; - a = b; - b = tmp; -} - -template -__dpct_inline__ static void -k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad, - const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) { - // bitonic sort - int col = item_ct1.get_local_id(2); - int row = item_ct1.get_group(1); - - if (col >= ncols_pad) { - return; - } - - const float * x_row = x + row * ncols; - auto dst_row = (int *)dpct_local; - - // initialize indices - dst_row[col] = col; - - item_ct1.barrier(sycl::access::fence_space::local_space); - - for (int k = 2; k <= ncols_pad; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - ggml_sycl_swap(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - ggml_sycl_swap(dst_row[col], dst_row[ixj]); - } - } - } - /* - DPCT1118:1: SYCL group functions and algorithms must be encountered - in converged control flow. You may need to adjust the code. - */ - item_ct1.barrier(sycl::access::fence_space::local_space); - } - } - - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; - } -} - - static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past, const sycl::nd_item<3> &item_ct1) { const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + @@ -2304,49 +2240,6 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, - const int nrows, ggml_sort_order order, - queue_ptr stream) { - // bitonic sort requires ncols to be power of 2 - const int ncols_pad = next_power_of_2(ncols); - - const sycl::range<3> block_dims(1, 1, ncols_pad); - const sycl::range<3> block_nums(1, nrows, 1); - const size_t shared_mem = ncols_pad * sizeof(int); - - if (order == GGML_SORT_ORDER_ASC) { - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor dpct_local_acc_ct1( - sycl::range<1>(shared_mem), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_argsort_f32_i32( - x, dst, ncols, ncols_pad, item_ct1, - dpct_local_acc_ct1.get_multi_ptr() - .get()); - }); - }); - } else if (order == GGML_SORT_ORDER_DESC) { - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor dpct_local_acc_ct1( - sycl::range<1>(shared_mem), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_argsort_f32_i32( - x, dst, ncols, ncols_pad, item_ct1, - dpct_local_acc_ct1.get_multi_ptr() - .get()); - }); - }); - } else { - GGML_ABORT("fatal error"); - } -} - static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, @@ -2678,22 +2571,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } -inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_I32); - - const int64_t ncols = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - dpct::queue_ptr main_stream = ctx.stream(); - const float * src0_dd = static_cast(dst->src[0]->data); - int32_t * dst_dd = static_cast(dst->data); - - argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream); -} - inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); @@ -3758,12 +3635,6 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds ggml_sycl_op_sum_rows(ctx, dst); } -static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_argsort(ctx, dst); -} - - void ggml_sycl_set_main_device(const int main_device) try { if (dpct::get_current_device_id() == static_cast (main_device)) { return;