argmax: move missing function to file and fix function name
This commit is contained in:
parent
a153f1972d
commit
51bedb847e
4 changed files with 10 additions and 10 deletions
|
@ -47,7 +47,7 @@ static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, con
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||||
|
|
||||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||||
|
|
|
@ -3,6 +3,6 @@
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
|
|
||||||
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
#endif // GGML_SYCL_ARGMAX_HPP
|
#endif // GGML_SYCL_ARGMAX_HPP
|
|
@ -1,5 +1,13 @@
|
||||||
#include "argsort.hpp"
|
#include "argsort.hpp"
|
||||||
|
|
||||||
|
static int next_power_of_2(int x) {
|
||||||
|
int n = 1;
|
||||||
|
while (n < x) {
|
||||||
|
n *= 2;
|
||||||
|
}
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static inline void ggml_sycl_swap(T & a, T & b) {
|
static inline void ggml_sycl_swap(T & a, T & b) {
|
||||||
T tmp = a;
|
T tmp = a;
|
||||||
|
|
|
@ -2232,14 +2232,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static int next_power_of_2(int x) {
|
|
||||||
int n = 1;
|
|
||||||
while (n < x) {
|
|
||||||
n *= 2;
|
|
||||||
}
|
|
||||||
return n;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
||||||
const int ncols_x, const int nrows_x,
|
const int ncols_x, const int nrows_x,
|
||||||
const int rows_per_channel, const int n_past,
|
const int rows_per_channel, const int n_past,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue