cuda : optimize argmax (#10441)
* cuda : optimize argmax
* remove unused parameter
  ggml-ci
* fixup : use full warps
  ggml-ci
* Apply suggestions from code review
  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* fix UB
* ggml : check ne00 <= INT32_MAX in argmax and argsort
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
1bb30bf28c
commit
a5e47592b6
5 changed files with 110 additions and 67 deletions
|
@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
GGML_ASSERT(ggml_is_matrix(a));
|
||||
GGML_ASSERT(a->ne[0] <= INT32_MAX);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
|
||||
|
||||
|
@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_sort_order order) {
|
||||
GGML_ASSERT(a->ne[0] <= INT32_MAX);
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) order);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue