ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

This commit is contained in:
Johannes Gäßler 2024-10-03 17:29:59 +02:00 committed by Georgi Gerganov
parent eee39bdc96
commit fabdc3bda3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 389 additions and 8 deletions

View file

@ -5,12 +5,14 @@
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
@ -2143,6 +2145,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
switch (dst->op) {
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_cuda_count_equal(ctx, dst);
break;
case GGML_OP_REPEAT:
ggml_cuda_op_repeat(ctx, dst);
break;
@ -3073,6 +3081,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
} break;
case GGML_OP_DUP:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
} break;
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
{
return true;
} break;
case GGML_OP_REPEAT:
{
ggml_type src0_type = op->src[0]->type;