From 51bedb847ec21ebcc718015228b68210905fc6e6 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sat, 1 Feb 2025 09:44:30 +0530 Subject: [PATCH] argmax: move missing function to file and fix function name --- ggml/src/ggml-sycl/argmax.cpp | 2 +- ggml/src/ggml-sycl/argmax.hpp | 2 +- ggml/src/ggml-sycl/argsort.cpp | 8 ++++++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 8 -------- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-sycl/argmax.cpp b/ggml/src/ggml-sycl/argmax.cpp index 573a9dc63..946565f87 100644 --- a/ggml/src/ggml-sycl/argmax.cpp +++ b/ggml/src/ggml-sycl/argmax.cpp @@ -47,7 +47,7 @@ static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, con }); } -void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { +static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-sycl/argmax.hpp b/ggml/src/ggml-sycl/argmax.hpp index 9888e4c08..9093528f2 100644 --- a/ggml/src/ggml-sycl/argmax.hpp +++ b/ggml/src/ggml-sycl/argmax.hpp @@ -3,6 +3,6 @@ #include "common.hpp" -void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_ARGMAX_HPP \ No newline at end of file diff --git a/ggml/src/ggml-sycl/argsort.cpp b/ggml/src/ggml-sycl/argsort.cpp index 74cb0afd6..8047f7d47 100644 --- a/ggml/src/ggml-sycl/argsort.cpp +++ b/ggml/src/ggml-sycl/argsort.cpp @@ -1,5 +1,13 @@ #include "argsort.hpp" +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + template static inline void ggml_sycl_swap(T & a, T & b) { T tmp = a; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0d771d61d..803ea6c23 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2232,14 +2232,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, }); } -static int next_power_of_2(int x) { - int n = 1; - while (n < x) { - n *= 2; - } - return n; -} - static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past,