Argsort: move to a separate file
This commit is contained in:
parent
95a09ab505
commit
5288bd5896
4 changed files with 129 additions and 129 deletions
120
ggml/src/ggml-sycl/argsort.cpp
Normal file
120
ggml/src/ggml-sycl/argsort.cpp
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
#include "argsort.hpp"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static inline void ggml_sycl_swap(T & a, T & b) {
|
||||||
|
T tmp = a;
|
||||||
|
a = b;
|
||||||
|
b = tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <ggml_sort_order order>
|
||||||
|
__dpct_inline__ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
|
||||||
|
const sycl::nd_item<3> & item_ct1, uint8_t * dpct_local) {
|
||||||
|
// bitonic sort
|
||||||
|
int col = item_ct1.get_local_id(2);
|
||||||
|
int row = item_ct1.get_group(1);
|
||||||
|
|
||||||
|
if (col >= ncols_pad) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * x_row = x + row * ncols;
|
||||||
|
auto dst_row = (int *) dpct_local;
|
||||||
|
|
||||||
|
// initialize indices
|
||||||
|
dst_row[col] = col;
|
||||||
|
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||||
|
for (int j = k / 2; j > 0; j /= 2) {
|
||||||
|
int ixj = col ^ j;
|
||||||
|
if (ixj > col) {
|
||||||
|
if ((col & k) == 0) {
|
||||||
|
if (dst_row[col] >= ncols ||
|
||||||
|
(dst_row[ixj] < ncols &&
|
||||||
|
(order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||||
|
x_row[dst_row[col]] < x_row[dst_row[ixj]]))) {
|
||||||
|
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (dst_row[ixj] >= ncols ||
|
||||||
|
(dst_row[col] < ncols &&
|
||||||
|
(order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||||
|
x_row[dst_row[col]] > x_row[dst_row[ixj]]))) {
|
||||||
|
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
||||||
|
in converged control flow. You may need to adjust the code.
|
||||||
|
*/
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy the result to dst without the padding
|
||||||
|
if (col < ncols) {
|
||||||
|
dst[row * ncols + col] = dst_row[col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void argsort_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order,
|
||||||
|
queue_ptr stream) {
|
||||||
|
// bitonic sort requires ncols to be power of 2
|
||||||
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
|
||||||
|
const sycl::range<3> block_dims(1, 1, ncols_pad);
|
||||||
|
const sycl::range<3> block_nums(1, nrows, 1);
|
||||||
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
stream->submit([&](sycl::handler & cgh) {
|
||||||
|
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
||||||
|
x, dst, ncols, ncols_pad, item_ct1,
|
||||||
|
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||||
|
stream->submit([&](sycl::handler & cgh) {
|
||||||
|
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
||||||
|
x, dst, ncols, ncols_pad, item_ct1,
|
||||||
|
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||||
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
const int64_t ncols = dst->src[0]->ne[0];
|
||||||
|
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||||
|
|
||||||
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
|
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
||||||
|
|
||||||
|
argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream);
|
||||||
|
} catch (const sycl::exception & exc) {
|
||||||
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||||
|
std::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||||
|
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||||
|
ggml_sycl_op_argsort(ctx, dst);
|
||||||
|
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||||
|
}
|
8
ggml/src/ggml-sycl/argsort.hpp
Normal file
8
ggml/src/ggml-sycl/argsort.hpp
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#ifndef GGML_SYCL_ARGSORT_HPP
|
||||||
|
#define GGML_SYCL_ARGSORT_HPP
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
|
||||||
|
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_ARGSORT_HPP
|
|
@ -31,6 +31,7 @@
|
||||||
#include "element_wise.hpp"
|
#include "element_wise.hpp"
|
||||||
#include "binbcast.hpp"
|
#include "binbcast.hpp"
|
||||||
#include "argmax.hpp"
|
#include "argmax.hpp"
|
||||||
|
#include "argsort.hpp"
|
||||||
#include "gla.hpp"
|
#include "gla.hpp"
|
||||||
|
|
||||||
#endif // GGML_SYCL_BACKEND_HPP
|
#endif // GGML_SYCL_BACKEND_HPP
|
||||||
|
|
|
@ -1730,70 +1730,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
static inline void ggml_sycl_swap(T & a, T & b) {
|
|
||||||
T tmp = a;
|
|
||||||
a = b;
|
|
||||||
b = tmp;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <ggml_sort_order order>
|
|
||||||
__dpct_inline__ static void
|
|
||||||
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
|
||||||
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
|
|
||||||
// bitonic sort
|
|
||||||
int col = item_ct1.get_local_id(2);
|
|
||||||
int row = item_ct1.get_group(1);
|
|
||||||
|
|
||||||
if (col >= ncols_pad) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float * x_row = x + row * ncols;
|
|
||||||
auto dst_row = (int *)dpct_local;
|
|
||||||
|
|
||||||
// initialize indices
|
|
||||||
dst_row[col] = col;
|
|
||||||
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
|
|
||||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
|
||||||
for (int j = k / 2; j > 0; j /= 2) {
|
|
||||||
int ixj = col ^ j;
|
|
||||||
if (ixj > col) {
|
|
||||||
if ((col & k) == 0) {
|
|
||||||
if (dst_row[col] >= ncols ||
|
|
||||||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
|
||||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
|
||||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
|
||||||
) {
|
|
||||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (dst_row[ixj] >= ncols ||
|
|
||||||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
|
||||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
|
||||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
|
||||||
) {
|
|
||||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
|
||||||
in converged control flow. You may need to adjust the code.
|
|
||||||
*/
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy the result to dst without the padding
|
|
||||||
if (col < ncols) {
|
|
||||||
dst[row * ncols + col] = dst_row[col];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||||
|
@ -2304,49 +2240,6 @@ static int next_power_of_2(int x) {
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
||||||
const int nrows, ggml_sort_order order,
|
|
||||||
queue_ptr stream) {
|
|
||||||
// bitonic sort requires ncols to be power of 2
|
|
||||||
const int ncols_pad = next_power_of_2(ncols);
|
|
||||||
|
|
||||||
const sycl::range<3> block_dims(1, 1, ncols_pad);
|
|
||||||
const sycl::range<3> block_nums(1, nrows, 1);
|
|
||||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
|
||||||
stream->submit([&](sycl::handler &cgh) {
|
|
||||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
|
||||||
sycl::range<1>(shared_mem), cgh);
|
|
||||||
|
|
||||||
cgh.parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
|
||||||
x, dst, ncols, ncols_pad, item_ct1,
|
|
||||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
|
||||||
.get());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
|
||||||
stream->submit([&](sycl::handler &cgh) {
|
|
||||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
|
||||||
sycl::range<1>(shared_mem), cgh);
|
|
||||||
|
|
||||||
cgh.parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
|
||||||
x, dst, ncols, ncols_pad, item_ct1,
|
|
||||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
|
||||||
.get());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
||||||
const int ncols_x, const int nrows_x,
|
const int ncols_x, const int nrows_x,
|
||||||
const int rows_per_channel, const int n_past,
|
const int rows_per_channel, const int n_past,
|
||||||
|
@ -2678,22 +2571,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
||||||
|
|
||||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
|
||||||
|
|
||||||
const int64_t ncols = dst->src[0]->ne[0];
|
|
||||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
||||||
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
|
||||||
|
|
||||||
argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||||
|
@ -3758,12 +3635,6 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
ggml_sycl_op_sum_rows(ctx, dst);
|
ggml_sycl_op_sum_rows(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
|
||||||
ggml_sycl_op_argsort(ctx, dst);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void ggml_sycl_set_main_device(const int main_device) try {
|
void ggml_sycl_set_main_device(const int main_device) try {
|
||||||
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue