diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 88a1c3a50..058325059 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1656,19 +1656,36 @@ struct ggml_tensor * llama_build_train_graphs(
     }
 
     if (alloc) {
-        // make sure t35 and t36 are not reallocated by inserting new temporary node depending on them
-        struct ggml_tensor * dep = ggml_scale_inplace(ctx, t35, t36);
+        // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
+        int n_leafs_before = gb->n_leafs;
         int n_nodes_before = gb->n_nodes;
-        ggml_build_forward_expand(gb, dep);
+        struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+        // output tensors
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+        // gradient tensors (will be set to zero by ggml_graph_reset)
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            if (!gf->grads[i]) continue;
+            ggml_allocr_alloc(alloc, gf->grads[i]);
+            ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
+        }
+        for (int i = 0; i < checkpoints.size(); ++i) {
+            ggml_allocr_alloc(alloc, checkpoints[i]);
+        }
+        int n_leafs_after = gb->n_leafs;
         int n_nodes_after = gb->n_nodes;
-        GGML_ASSERT(n_nodes_after == n_nodes_before + 1);
 
-        ggml_allocr_reset(alloc);
         ggml_allocr_alloc_graph(alloc, gb);
 
-        // remove the additional node that was insert
-        gb->nodes[n_nodes_after-1] = NULL;
+        // remove the additional nodes and leafs
+        for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
+            gb->leafs[i] = NULL;
+        }
+        for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
+            gb->nodes[i] = NULL;
+        }
+        gb->n_leafs = n_leafs_before;
         gb->n_nodes = n_nodes_before;
     }
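
Note on the hunk above (not part of the patch): ggml-alloc recycles a tensor's buffer as soon as no later node in the graph reads it, so tensors that must survive the whole backward pass (the outputs t35/t36, the gradient tensors, the checkpoints) are pinned by appending dummy ggml_scale_inplace(x, 1.0) nodes that "use" them at the very end of the graph. Once ggml_allocr_alloc_graph has assigned addresses, those extra nodes and leafs are nulled out and the n_leafs/n_nodes counters rolled back, so they only influence allocation and are never executed. A minimal sketch of the pinning step, using only ggml calls that appear in the patch; the helper name pin_tensor is made up for illustration:

    #include "ggml.h"

    // sketch only: keep `t` alive for the whole graph by appending a dummy
    // node that depends on it
    static void pin_tensor(struct ggml_context * ctx, struct ggml_cgraph * gb,
                           struct ggml_tensor * t, struct ggml_tensor * one) {
        // ggml_scale_inplace(ctx, t, one) writes t*1.0 back into t's own buffer;
        // expanding it into gb gives the allocator a "last use" of t at the end
        // of the graph, so t's memory is not handed out to other tensors
        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t, one));
    }

The n_leafs_before/n_nodes_before bookkeeping in the patch is what makes this trick safe: after allocation the temporary entries are set to NULL and the counts restored, so graph execution sees the same graph as before the pinning pass.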