argmax: move missing function to file and fix function name

This commit is contained in:
Akarshan Biswas 2025-02-01 09:44:30 +05:30
parent a153f1972d
commit 51bedb847e
No known key found for this signature in database
GPG key ID: 52A578A14B32134D
4 changed files with 10 additions and 10 deletions

View file

@ -47,7 +47,7 @@ static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, con
}); });
} }
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try { static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);

View file

@ -3,6 +3,6 @@
#include "common.hpp" #include "common.hpp"
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ARGMAX_HPP #endif // GGML_SYCL_ARGMAX_HPP

View file

@ -1,5 +1,13 @@
#include "argsort.hpp" #include "argsort.hpp"
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
template <typename T> template <typename T>
static inline void ggml_sycl_swap(T & a, T & b) { static inline void ggml_sycl_swap(T & a, T & b) {
T tmp = a; T tmp = a;

View file

@ -2232,14 +2232,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
}); });
} }
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
static void diag_mask_inf_f32_sycl(const float *x, float *dst, static void diag_mask_inf_f32_sycl(const float *x, float *dst,
const int ncols_x, const int nrows_x, const int ncols_x, const int nrows_x,
const int rows_per_channel, const int n_past, const int rows_per_channel, const int n_past,