From 95a09ab5056efbf5c69bab16fcb4966827305918 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sat, 1 Feb 2025 09:22:25 +0530 Subject: [PATCH] ARGMAX: move to a separate file --- ggml/src/ggml-sycl/argmax.cpp | 73 ++++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/argmax.hpp | 8 ++++ ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 68 ----------------------------- 4 files changed, 82 insertions(+), 68 deletions(-) create mode 100644 ggml/src/ggml-sycl/argmax.cpp create mode 100644 ggml/src/ggml-sycl/argmax.hpp diff --git a/ggml/src/ggml-sycl/argmax.cpp b/ggml/src/ggml-sycl/argmax.cpp new file mode 100644 index 000000000..573a9dc63 --- /dev/null +++ b/ggml/src/ggml-sycl/argmax.cpp @@ -0,0 +1,73 @@ +#include "argmax.hpp" + +static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = 256 * sizeof(float); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor shared_data(sycl::range<1>(shared_mem / sizeof(float)), cgh); + sycl::local_accessor shared_indices(sycl::range<1>(shared_mem / sizeof(float)), cgh); + + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + const int tid = item_ct1.get_local_id(2); + const int row = item_ct1.get_global_id(1); + + float max_val = -INFINITY; + int max_idx = -1; + + for (int col = tid; col < ncols; col += 256) { + float val = x[row * ncols + col]; + if (val > max_val) { + max_val = val; + max_idx = col; + } + } + + shared_data[tid] = max_val; + shared_indices[tid] = max_idx; + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int stride = 256 / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + float val1 = shared_data[tid]; + float val2 = shared_data[tid + stride]; + if (val2 > val1) { + shared_data[tid] = val2; + shared_indices[tid] = shared_indices[tid + stride]; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + if (tid == 0) { + dst[row] = shared_indices[0]; + } + }); + }); +} + +void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + dpct::queue_ptr main_stream = ctx.stream(); + const float * src0_dd = static_cast(dst->src[0]->data); + int32_t * dst_dd = static_cast(dst->data); + argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_argmax(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} \ No newline at end of file diff --git a/ggml/src/ggml-sycl/argmax.hpp b/ggml/src/ggml-sycl/argmax.hpp new file mode 100644 index 000000000..9888e4c08 --- /dev/null +++ b/ggml/src/ggml-sycl/argmax.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_ARGMAX_HPP +#define GGML_SYCL_ARGMAX_HPP + +#include "common.hpp" + +void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_ARGMAX_HPP \ No newline at end of file diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index cdb89e392..05bc85ded 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -30,6 +30,7 @@ #include "outprod.hpp" #include "element_wise.hpp" #include "binbcast.hpp" +#include "argmax.hpp" #include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f618fef80..f9ea4258e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2347,58 +2347,6 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } -static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, - const int nrows, queue_ptr stream) { - const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); - const sycl::range<3> block_nums(1, nrows, 1); - const size_t shared_mem = 256 * sizeof(float); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor shared_data( - sycl::range<1>(shared_mem/sizeof(float)), cgh); - sycl::local_accessor shared_indices( - sycl::range<1>(shared_mem/sizeof(float)), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - const int tid = item_ct1.get_local_id(2); - const int row = item_ct1.get_global_id(1); - - float max_val = -INFINITY; - int max_idx = -1; - - for (int col = tid; col < ncols; col += 256) { - float val = x[row * ncols + col]; - if (val > max_val) { - max_val = val; - max_idx = col; - } - } - - shared_data[tid] = max_val; - shared_indices[tid] = max_idx; - item_ct1.barrier(sycl::access::fence_space::local_space); - - for (int stride = 256/2; stride > 0; stride >>= 1) { - if (tid < stride) { - float val1 = shared_data[tid]; - float val2 = shared_data[tid + stride]; - if (val2 > val1) { - shared_data[tid] = val2; - shared_indices[tid] = shared_indices[tid + stride]; - } - } - item_ct1.barrier(sycl::access::fence_space::local_space); - } - - - if (tid == 0) { - dst[row] = shared_indices[0]; - } - }); - }); -} static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, @@ -2746,22 +2694,6 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream); } -inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_I32); - - const int64_t ncols = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); - - dpct::queue_ptr main_stream = ctx.stream(); - const float * src0_dd = static_cast(dst->src[0]->data); - int32_t * dst_dd = static_cast(dst->data); - - argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); -} - inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);