ggml : check ne00 <= INT32_MAX in argmax and argsort

This commit is contained in:
slaren 2024-11-21 15:01:09 +01:00
parent 316f3d3116
commit 48f94d41d9

View file

@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_matrix(a));
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);