avoid stack overflow of large cgraphs in test-grad0

This commit is contained in:
xaedes 2023-08-29 19:59:41 +02:00
parent 794bb7ea42
commit 5f0a4e971f
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -251,18 +251,19 @@ static bool check_gradient(
printf("GGML_N_THREADS = %d\n", n_threads);
}
struct ggml_cgraph gf = ggml_build_forward (f);
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
struct ggml_cgraph * gb = ggml_new_graph(ctx0);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
ggml_graph_reset (&gf);
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
// ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot");
for (int i = 0; i < nargs; ++i) {
const int nelements = ggml_nelements(x[i]);
@ -273,13 +274,13 @@ static bool check_gradient(
const float xp = x0 + eps;
ggml_set_f32_1d(x[i], k, xp);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
const double f0 = ggml_get_f32_1d(f, 0);
ggml_set_f32_1d(x[i], k, xm);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
const double f1 = ggml_get_f32_1d(f, 0);
const double g0 = (f0 - f1)/(2.0*(double) eps);
@ -287,10 +288,10 @@ static bool check_gradient(
ggml_set_f32_1d(x[i], k, x0);
// compute gradient using backward graph
ggml_graph_reset (&gf);
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
const double g1 = ggml_get_f32_1d(x[i]->grad, k);