From 5c05a3eedc33d28aefef8424137f3351fc53e318 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sun, 2 Feb 2025 18:16:41 +0530 Subject: [PATCH] Move sum and sum rows to a separate file --- ggml/src/ggml-sycl/ggml-sycl.cpp | 49 ---------------------- ggml/src/ggml-sycl/sum.cpp | 72 ++++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/sum.hpp | 9 ++++ 3 files changed, 81 insertions(+), 49 deletions(-) create mode 100644 ggml/src/ggml-sycl/sum.cpp create mode 100644 ggml/src/ggml-sycl/sum.hpp diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 346a32260..451bb2bae 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1528,17 +1528,6 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( } } -static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, - const int nrows, queue_ptr stream) { - const sycl::range<3> block_dims(1, 1, WARP_SIZE); - const sycl::range<3> block_nums(1, nrows, 1); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { - k_sum_rows_f32(x, dst, ncols, item_ct1); - }); -} - static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, const struct ggml_tensor *src, int64_t i3, int64_t i2, @@ -1752,34 +1741,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); - - const int64_t ne = ggml_nelements(dst->src[0]); - dpct::queue_ptr main_stream = ctx.stream(); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - float * dst_dd = static_cast<float *>(dst->data); - - sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); -} - -inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - - 
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); - - const int64_t ncols = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); - dpct::queue_ptr main_stream = ctx.stream(); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - float * dst_dd = static_cast<float *>(dst->data); - - sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); -} - static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { static bool peer_access_enabled = false; @@ -2701,16 +2662,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_sum(ctx, dst); -} - -static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - ggml_sycl_op_sum_rows(ctx, dst); -} - void ggml_sycl_set_main_device(const int main_device) try { if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) { return; diff --git a/ggml/src/ggml-sycl/sum.cpp b/ggml/src/ggml-sycl/sum.cpp new file mode 100644 index 000000000..be94b1784 --- /dev/null +++ b/ggml/src/ggml-sycl/sum.cpp @@ -0,0 +1,72 @@ +#include "sum.hpp" + +static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> & item_ct1) { + const int row = item_ct1.get_group(1); + const int col = item_ct1.get_local_id(2); + + float sum = 0.0f; + for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum, item_ct1); + + if (col == 0) { + dst[row] = sum; + } +} + +static void sum_rows_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + 
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { k_sum_rows_f32(x, dst, ncols, item_ct1); }); +} + +inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0); + + const int64_t ne = ggml_nelements(dst->src[0]); + dpct::queue_ptr main_stream = ctx.stream(); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); + + sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + dpct::queue_ptr main_stream = ctx.stream(); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); + + sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_sum(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, 
ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_sum_rows(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} diff --git a/ggml/src/ggml-sycl/sum.hpp b/ggml/src/ggml-sycl/sum.hpp new file mode 100644 index 000000000..d1b8e5a7c --- /dev/null +++ b/ggml/src/ggml-sycl/sum.hpp @@ -0,0 +1,9 @@ +#ifndef GGML_SYCL_SUM_HPP +#define GGML_SYCL_SUM_HPP + +#include "common.hpp" + +void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_SUM_HPP