ggml/examples: add backend support for numerical optimization (ggml/949)

* CUDA eval works

* stochastic gradient descent op

* Adam except decay

* CUDA CROSS_ENTROPY_LOSS_BACK

* CUDA mnist-fc training works

* backend CLI arg

* refactor gguf load

* remove sched from opt_step_adam

* implement l1 regularization (weight decay)

* extra call to add optimizer

* initialize gradients with ggml_graph_reset

* gradient accumulation

* increment iter per eval instead of epoch

* adjust backend interfaces

* fix ggml_graph_reset without backend

* fix ggml graph export/import

* fixup

* rename

* revert ggml_opt changes

* more general CUDA repeat_back

* update documentation, fix CNN

* validation split

* add clarifying comment

* optimize PyTorch training

* adjust buffer size, thread count

* fix 0.0f validation split

* Update examples/mnist/mnist-common.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix gradient accumulation

* tensor flag for accumulators -> tensor hash set

* Update include/ggml.h

Co-authored-by: slaren <slarengh@gmail.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <slarengh@gmail.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <slarengh@gmail.com>

* fix test prints

* Update src/ggml-backend.c

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* better CUDA support for noncontiguous out_prod

* add comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Johannes Gäßler 2024-09-20 19:04:44 +03:00 committed by Georgi Gerganov
parent a6809c6a2e
commit 424c5d00a9
24 changed files with 883 additions and 129 deletions

View file

@ -240,7 +240,7 @@ static bool check_gradient(
struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
ggml_build_forward_expand(gf, f);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_build_backward_expand(ctx0, gf, gb, false, false);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);