Move sum and sum rows to a separate file
This commit is contained in:
parent
eb466d733a
commit
5c05a3eedc
3 changed files with 81 additions and 49 deletions
|
@@ -1528,17 +1528,6 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
|
||||||
const int nrows, queue_ptr stream) {
|
|
||||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
||||||
const sycl::range<3> block_nums(1, nrows, 1);
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
k_sum_rows_f32(x, dst, ncols, item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
|
static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
|
||||||
const struct ggml_tensor *src,
|
const struct ggml_tensor *src,
|
||||||
int64_t i3, int64_t i2,
|
int64_t i3, int64_t i2,
|
||||||
|
@@ -1752,34 +1741,6 @@ catch (sycl::exception const &exc) {
|
||||||
std::exit(1);
|
std::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sum every element of dst->src[0] (F32) into a single float in dst.
// Implemented by treating the whole tensor as one row and reusing the
// row-sum kernel with nrows == 1.
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));

    const int64_t n_elements = ggml_nelements(src0);

    dpct::queue_ptr stream = ctx.stream();
    const float * src_data = static_cast<const float *>(src0->data);
    float * dst_data = static_cast<float *>(dst->data);

    // ncols = total element count, nrows = 1 -> one global sum.
    sum_rows_f32_sycl(src_data, dst_data, n_elements, 1, stream);
}
|
|
||||||
|
|
||||||
// Row-wise sum of dst->src[0] (F32): writes one float per row into dst.
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));

    const int64_t n_cols = src0->ne[0];
    const int64_t n_rows = ggml_nrows(src0);

    dpct::queue_ptr stream = ctx.stream();
    const float * src_data = static_cast<const float *>(src0->data);
    float * dst_data = static_cast<float *>(dst->data);

    sum_rows_f32_sycl(src_data, dst_data, n_cols, n_rows, stream);
}
|
|
||||||
|
|
||||||
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
|
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
|
||||||
static bool peer_access_enabled = false;
|
static bool peer_access_enabled = false;
|
||||||
|
|
||||||
|
@@ -2701,16 +2662,6 @@ catch (sycl::exception const &exc) {
|
||||||
std::exit(1);
|
std::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Entry point for the full-tensor sum op on the SYCL backend.
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // The reduction kernel reads the input linearly, so it must be contiguous.
    const ggml_tensor * src0 = dst->src[0];
    GGML_ASSERT(ggml_is_contiguous(src0));

    ggml_sycl_op_sum(ctx, dst);
}
|
|
||||||
|
|
||||||
// Entry point for the row-sum op on the SYCL backend.
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // The kernel indexes rows as row * ncols, so the input must be contiguous.
    const ggml_tensor * src0 = dst->src[0];
    GGML_ASSERT(ggml_is_contiguous(src0));

    ggml_sycl_op_sum_rows(ctx, dst);
}
|
|
||||||
|
|
||||||
void ggml_sycl_set_main_device(const int main_device) try {
|
void ggml_sycl_set_main_device(const int main_device) try {
|
||||||
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
||||||
return;
|
return;
|
||||||
|
|
72
ggml/src/ggml-sycl/sum.cpp
Normal file
72
ggml/src/ggml-sycl/sum.cpp
Normal file
|
@@ -0,0 +1,72 @@
|
||||||
|
#include "sum.hpp"
|
||||||
|
|
||||||
|
// Kernel: each work-group reduces one row of x (ncols floats) into dst[row].
// Work-items stride over the columns accumulating a partial sum, partials
// are combined with warp_reduce_sum, and lane 0 writes the result.
static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> & item_ct1) {
    const int row_idx = item_ct1.get_group(1);
    const int lane    = item_ct1.get_local_id(2);
    const int stride  = item_ct1.get_local_range(2);

    const float * row_ptr = x + row_idx * ncols;

    float partial = 0.0f;
    for (int c = lane; c < ncols; c += stride) {
        partial += row_ptr[c];
    }

    partial = warp_reduce_sum(partial, item_ct1);

    if (lane == 0) {
        dst[row_idx] = partial;
    }
}
|
||||||
|
|
||||||
|
static void sum_rows_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, queue_ptr stream) {
|
||||||
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||||
|
const sycl::range<3> block_nums(1, nrows, 1);
|
||||||
|
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[intel::reqd_sub_group_size(WARP_SIZE)]] { k_sum_rows_f32(x, dst, ncols, item_ct1); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum every element of dst->src[0] (F32) into a single float in dst.
// Reuses the row-sum kernel by treating the whole tensor as one row
// (ncols = element count, nrows = 1). Exits the process on SYCL errors.
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // Reject split buffer types (name matches GGML_SYCL_NAME "_Split"):
    // this op runs on a single contiguous allocation.
    // Fixed: was `GML_ASSERT` (typo — undeclared macro, would not compile).
    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);

    const int64_t ne = ggml_nelements(dst->src[0]);
    dpct::queue_ptr main_stream = ctx.stream();
    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
    float * dst_dd = static_cast<float *>(dst->data);

    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
|
||||||
|
|
||||||
|
// Row-wise sum of dst->src[0] (F32): one float per row written into dst.
// Exits the process on SYCL errors.
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // Reject split buffer types (name matches GGML_SYCL_NAME "_Split"):
    // this op runs on a single contiguous allocation.
    // Fixed: was `GML_ASSERT` (typo — undeclared macro, would not compile).
    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);

    const int64_t ncols = dst->src[0]->ne[0];
    const int64_t nrows = ggml_nrows(dst->src[0]);
    dpct::queue_ptr main_stream = ctx.stream();
    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
    float * dst_dd = static_cast<float *>(dst->data);

    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
|
||||||
|
|
||||||
|
// Public entry point for the full-tensor sum op on the SYCL backend.
void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // The reduction kernel reads the input linearly, so it must be contiguous.
    const ggml_tensor * src0 = dst->src[0];
    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_sum(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}
|
||||||
|
|
||||||
|
// Public entry point for the row-sum op on the SYCL backend.
void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // The kernel indexes rows as row * ncols, so the input must be contiguous.
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_sum_rows(ctx, dst);
    // Fixed: was `GML_SYCL_DEBUG` (typo — undeclared macro, would not compile).
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}
|
9
ggml/src/ggml-sycl/sum.hpp
Normal file
9
ggml/src/ggml-sycl/sum.hpp
Normal file
|
@@ -0,0 +1,9 @@
|
||||||
|
// Public interface for the SYCL sum / sum-rows ops (see sum.cpp).
#ifndef GGML_SYCL_SUM_HPP
#define GGML_SYCL_SUM_HPP

#include "common.hpp"

// Sums every element of dst->src[0] (F32) into a single float stored in dst.
void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
// Sums each row of dst->src[0] (F32); dst receives one float per row.
void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_SUM_HPP
|
Loading…
Add table
Add a link
Reference in a new issue