ggml: fix gradient allocation logic (ggml/966)

* ggml: fix gradient allocation logic

* gradient allocation in ggml_build_backward_expand

* fixup

* fix test-backend-ops grad

* suggestions by slaren

* fix test1.c

* fix legacy opt API

* fix test-grad0

* remove keep arg
This commit is contained in:
Johannes Gäßler 2024-09-29 23:18:02 +02:00 committed by Georgi Gerganov
parent cad341d889
commit 7254cdf7e8
No known key found for this signature in database
GPG key ID: BF970631944C16B7
4 changed files with 490 additions and 1065 deletions

View file

@@ -240,12 +240,14 @@ static bool check_gradient(
struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
ggml_build_forward_expand(gf, f);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx0, gf, gb, false, false);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_reset(gb);
if (f->grad) {
ggml_set_f32(f->grad, 1.0f);
}
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
@@ -298,8 +300,10 @@ static bool check_gradient(
ggml_set_f32_1d(x[i], k, x0);
// compute gradient using backward graph
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_reset(gb);
if (f->grad) {
ggml_set_f32(f->grad, 1.0f);
}
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);