argmax: move missing function to file and fix function name
This commit is contained in:
parent
a153f1972d
commit
51bedb847e
4 changed files with 10 additions and 10 deletions
|
@ -47,7 +47,7 @@ static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, con
|
|||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||
static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
|
|
|
@ -3,6 +3,6 @@
|
|||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_ARGMAX_HPP
|
|
@ -1,5 +1,13 @@
|
|||
#include "argsort.hpp"
|
||||
|
||||
static int next_power_of_2(int x) {
|
||||
int n = 1;
|
||||
while (n < x) {
|
||||
n *= 2;
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline void ggml_sycl_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
|
|
|
@ -2232,14 +2232,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
|||
});
|
||||
}
|
||||
|
||||
static int next_power_of_2(int x) {
|
||||
int n = 1;
|
||||
while (n < x) {
|
||||
n *= 2;
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
||||
const int ncols_x, const int nrows_x,
|
||||
const int rows_per_channel, const int n_past,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue