From 8e86732cf272b343fde55819f0a3ecaa3df860d9 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sat, 1 Feb 2025 19:33:52 +0530 Subject: [PATCH] diagmask: move to a separate file --- ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/diagmask.cpp | 53 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/diagmask.hpp | 8 +++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 55 -------------------------------- 4 files changed, 62 insertions(+), 55 deletions(-) create mode 100644 ggml/src/ggml-sycl/diagmask.cpp create mode 100644 ggml/src/ggml-sycl/diagmask.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 24cf492b3..92519caf8 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -34,6 +34,7 @@ #include "argsort.hpp" #include "cpy.hpp" #include "getrows.hpp" +#include "diagmask.hpp" #include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/diagmask.cpp b/ggml/src/ggml-sycl/diagmask.cpp new file mode 100644 index 000000000..821c8c699 --- /dev/null +++ b/ggml/src/ggml-sycl/diagmask.cpp @@ -0,0 +1,53 @@ +#include "diagmask.hpp" +#include <float.h> + +static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, + const int n_past, const sycl::nd_item<3> & item_ct1) { + const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + if (col >= ncols) { + return; + } + + const int i = row * ncols + col; + //dst[i] = col > (n_past + row % rows_per_channel) ? 
-INFINITY : x[i]; + //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; +} + +static void diag_mask_inf_f32_sycl(const float * x, float * dst, const int ncols_x, const int nrows_x, + const int rows_per_channel, const int n_past, queue_ptr stream) { + const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1); + const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE; + const sycl::range<3> block_nums(1, block_num_x, nrows_x); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + diag_mask_inf_f32(x, dst, ncols_x, rows_per_channel, n_past, item_ct1); + }); +} + +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0); + + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t ne01 = dst->src[0]->ne[1]; + const int nrows0 = ggml_nrows(dst->src[0]); + + const int n_past = ((int32_t *) dst->op_params)[0]; + dpct::queue_ptr main_stream = ctx.stream(); + const float * src0_dd = static_cast<const float *>(dst->src[0]->data); + float * dst_dd = static_cast<float *>(dst->data); + + diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_diag_mask_inf(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} \ No newline at end of file diff --git 
a/ggml/src/ggml-sycl/diagmask.hpp b/ggml/src/ggml-sycl/diagmask.hpp new file mode 100644 index 000000000..37954aedc --- /dev/null +++ b/ggml/src/ggml-sycl/diagmask.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_DIAG_MASK +#define GGML_SYCL_DIAG_MASK + +#include "common.hpp" + +void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_DIAG_MASK \ No newline at end of file diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 93f936199..2e8b4852a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1463,24 +1463,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols, } } - -static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past, - const sycl::nd_item<3> &item_ct1) { - const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (col >= ncols) { - return; - } - - const int i = row*ncols + col; - //dst[i] = col > (n_past + row % rows_per_channel) ? 
-INFINITY : x[i]; - //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU - dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; -} - static void scale_f32(const float * x, float * dst, const float scale, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -1666,21 +1648,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, }); } -static void diag_mask_inf_f32_sycl(const float *x, float *dst, - const int ncols_x, const int nrows_x, - const int rows_per_channel, const int n_past, - queue_ptr stream) { - const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1); - const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE; - const sycl::range<3> block_nums(1, block_num_x, nrows_x); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - diag_mask_inf_f32(x, dst, ncols_x, - rows_per_channel, n_past, - item_ct1); - }); -} - static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, const struct ggml_tensor *src, int64_t i3, int64_t i2, @@ -1962,24 +1929,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } -inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); - - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t ne01 = dst->src[0]->ne[1]; - const int nrows0 = ggml_nrows(dst->src[0]); - - const int n_past = ((int32_t *) dst->op_params)[0]; - dpct::queue_ptr main_stream = ctx.stream(); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - float * dst_dd = static_cast<float *>(dst->data); 
- - diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); -} - inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); @@ -2957,10 +2906,6 @@ static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_cpy(ctx, dst->src[0], dst); } -static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_diag_mask_inf(ctx, dst); -} - static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented ggml_sycl_op_rope(ctx, dst);