Move sum and sum rows to a separate file

This commit is contained in:
Akarshan Biswas 2025-02-02 18:16:41 +05:30
parent eb466d733a
commit 5c05a3eedc
No known key found for this signature in database
GPG key ID: 52A578A14B32134D
3 changed files with 81 additions and 49 deletions

View file

@@ -1528,17 +1528,6 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
}
}
// Launch the row-sum kernel: one work-group of WARP_SIZE work-items per row of
// x; each group writes the sum of its row to dst[row] via k_sum_rows_f32.
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const int nrows, queue_ptr stream) {
// local size: WARP_SIZE items along dim 2; one group per row along dim 1
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
// sub-group size pinned to WARP_SIZE so the warp reduction is valid
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
k_sum_rows_f32(x, dst, ncols, item_ct1);
});
}
static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
const struct ggml_tensor *src,
int64_t i3, int64_t i2,
@@ -1752,34 +1741,6 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
// Sum every element of src0 (F32) into the single F32 value dst.
// Implemented by treating the whole tensor as one row of ggml_nelements()
// columns and running the row-sum kernel once.
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
// split buffers are not supported by this op
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
const int64_t ne = ggml_nelements(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
// nrows == 1: reduce the flattened tensor in a single pass
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
}
// Sum each row of src0 (F32) independently, writing one F32 per row to dst.
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
// split buffers are not supported by this op
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
}
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
static bool peer_access_enabled = false;
@@ -2701,16 +2662,6 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
// GGML_OP_SUM entry point: requires a contiguous input so the tensor can be
// reduced as one flat row, then delegates to ggml_sycl_op_sum.
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_sum(ctx, dst);
}
// GGML_OP_SUM_ROWS entry point: requires a contiguous input, then delegates
// to ggml_sycl_op_sum_rows.
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_sum_rows(ctx, dst);
}
void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
return;

View file

@@ -0,0 +1,72 @@
#include "sum.hpp"
// Kernel: each work-group (indexed along dim 1) reduces one row of x to a
// single float in dst[row]. Work-items stride over the columns accumulating a
// partial sum, which is then combined with a warp-level reduction.
static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> & item_ct1) {
    const int row_idx = item_ct1.get_group(1);
    const int lane    = item_ct1.get_local_id(2);
    const int stride  = item_ct1.get_local_range(2);

    float partial = 0.0f;
    for (int c = lane; c < ncols; c += stride) {
        partial += x[row_idx * ncols + c];
    }

    // Combine the per-lane partials across the sub-group.
    partial = warp_reduce_sum(partial, item_ct1);

    // Lane 0 now holds the complete row sum.
    if (lane == 0) {
        dst[row_idx] = partial;
    }
}
static void sum_rows_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { k_sum_rows_f32(x, dst, ncols, item_ct1); });
}
// Sum every element of src0 (F32) into the single F32 value dst.
// Implemented by treating the whole tensor as one row of ggml_nelements()
// columns and running the row-sum kernel once.
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // Split buffers are not supported by this op.
    // (fixed: was the undefined macro GML_ASSERT, which does not compile)
    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);

    const int64_t ne = ggml_nelements(dst->src[0]);
    dpct::queue_ptr main_stream = ctx.stream();
    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
    float * dst_dd = static_cast<float *>(dst->data);

    // nrows == 1: reduce the flattened tensor in a single pass.
    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
// Sum each row of src0 (F32) independently, writing one F32 per row to dst.
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // Split buffers are not supported by this op.
    // (fixed: was the undefined macro GML_ASSERT, which does not compile)
    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);

    const int64_t ncols = dst->src[0]->ne[0];
    const int64_t nrows = ggml_nrows(dst->src[0]);
    dpct::queue_ptr main_stream = ctx.stream();
    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
    float * dst_dd = static_cast<float *>(dst->data);

    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
// GGML_OP_SUM entry point: requires a contiguous input so the tensor can be
// reduced as one flat row, then delegates to ggml_sycl_op_sum.
void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_sum(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
// GGML_OP_SUM_ROWS entry point: requires a contiguous input, then delegates
// to ggml_sycl_op_sum_rows.
void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_sum_rows(ctx, dst);
    // fixed: was the undefined macro GML_SYCL_DEBUG, which does not compile
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

View file

@@ -0,0 +1,9 @@
#ifndef GGML_SYCL_SUM_HPP
#define GGML_SYCL_SUM_HPP
#include "common.hpp"
// GGML_OP_SUM: reduce all elements of dst->src[0] (F32) into a single F32.
void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
// GGML_OP_SUM_ROWS: reduce each row of dst->src[0] (F32) to one F32 per row.
void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_SUM_HPP