From ed4319e1a78e38777a3d0174f667829d9c0cc271 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Fri, 28 Jul 2023 23:08:11 +0200
Subject: [PATCH] add and use function ggml_build_backward_expand to avoid
 stack overflows with large maximum number of nodes

GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
---
 .../train-text-from-scratch.cpp               | 16 ++++++++++------
 ggml.c                                        | 10 ++++++----
 ggml.h                                        |  3 ++-
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 075e0307f..61def445e 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -3957,12 +3957,14 @@ int main(int argc, char ** argv) {
                 logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
                 loss = cross_entropy_loss(ctx0, logits, target_probs);
                 ggml_build_forward_expand(gf, loss);
-                *gb = ggml_build_backward(ctx0, gf, true);
+                *gb = *gf;
+                ggml_build_backward_expand(ctx0, gf, gb, true);
             } else {
                 logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
                 loss = cross_entropy_loss(ctx0, logits, target_probs);
                 ggml_build_forward_expand(gf, loss);
-                *gb = ggml_build_backward(ctx0, gf, true);
+                *gb = *gf;
+                ggml_build_backward_expand(ctx0, gf, gb, true);
             }
 
             ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
@@ -4070,13 +4072,15 @@
         };
         struct ggml_context * ctx0 = ggml_init(cparams);
 
-        ggml_cgraph gf = {};
+        struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
+        memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
+        struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
 
         int n_past = 0;
-        struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
+        struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
 
-        ggml_build_forward_expand(&gf, logits);
-        ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
+        ggml_build_forward_expand(gf, logits);
+        ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
 
         //struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
         //struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
diff --git a/ggml.c b/ggml.c
index 19a194beb..92717f0aa 100644
--- a/ggml.c
+++ b/ggml.c
@@ -15787,9 +15787,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     return result;
 }
 
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
-    struct ggml_cgraph result = *gf;
-
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
 
     // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
@@ -15818,10 +15816,14 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_expand(&result, node->grad);
+            ggml_build_forward_expand(gb, node->grad);
         }
     }
+}
 
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+    struct ggml_cgraph result = *gf;
+    ggml_build_backward_expand(ctx, gf, &result, keep);
     return result;
 }
 
diff --git a/ggml.h b/ggml.h
index 460976468..8f51f5d22 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1403,7 +1403,8 @@ extern "C" {
             struct ggml_tensor  * tensor);
 
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);