cuda : optimize argmax (#10441)

* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Diego Devesa 2024-11-21 18:18:50 +01:00 committed by GitHub
parent 1bb30bf28c
commit a5e47592b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 110 additions and 67 deletions

View file

@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_matrix(a));
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);