ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)
This commit is contained in:
parent
eee39bdc96
commit
fabdc3bda3
11 changed files with 389 additions and 8 deletions
|
@ -5,12 +5,14 @@
|
|||
#include "ggml-cuda/common.cuh"
|
||||
#include "ggml-cuda/acc.cuh"
|
||||
#include "ggml-cuda/arange.cuh"
|
||||
#include "ggml-cuda/argmax.cuh"
|
||||
#include "ggml-cuda/argsort.cuh"
|
||||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
#include "ggml-cuda/count-equal.cuh"
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
|
@ -2143,6 +2145,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
}
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ARGMAX:
|
||||
ggml_cuda_argmax(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
ggml_cuda_count_equal(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_cuda_op_repeat(ctx, dst);
|
||||
break;
|
||||
|
@ -3073,6 +3081,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return false;
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
{
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_REPEAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue