make sure some tensors are not reallocated by inserting new temporary nodes depending on them:

output and parameter gradient tensors need to be available at the end of the graph execution

parameter gradient tensors also need to be available before the graph execution because they are set to zero before each optimizer iteration

checkpoint tensors are allocated all together to reduce memory allocator fragmentation

afterwards, in addition to the temporary nodes, we also need to reset the temporary leafs
This commit is contained in:
xaedes 2023-08-14 18:07:16 +02:00
parent 9716eb8ef0
commit 38f4438c32
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1656,19 +1656,36 @@ struct ggml_tensor * llama_build_train_graphs(
} }
if (alloc) { if (alloc) {
// make sure t35 and t36 are not reallocated by inserting new temporary node depending on them // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
struct ggml_tensor * dep = ggml_scale_inplace(ctx, t35, t36); int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes; int n_nodes_before = gb->n_nodes;
ggml_build_forward_expand(gb, dep); struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
// gradient tensors (will be set to zero by ggml_graph_reset)
for (int i = 0; i < gf->n_nodes; ++i) {
if (!gf->grads[i]) continue;
ggml_allocr_alloc(alloc, gf->grads[i]);
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
}
for (int i = 0; i < checkpoints.size(); ++i) {
ggml_allocr_alloc(alloc, checkpoints[i]);
}
int n_leafs_after = gb->n_leafs;
int n_nodes_after = gb->n_nodes; int n_nodes_after = gb->n_nodes;
GGML_ASSERT(n_nodes_after == n_nodes_before + 1);
ggml_allocr_reset(alloc);
ggml_allocr_alloc_graph(alloc, gb); ggml_allocr_alloc_graph(alloc, gb);
// remove the additional node that was insert // remove the additional nodes and leafs
gb->nodes[n_nodes_after-1] = NULL; for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
gb->leafs[i] = NULL;
}
for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
gb->nodes[i] = NULL;
}
gb->n_leafs = n_leafs_before;
gb->n_nodes = n_nodes_before; gb->n_nodes = n_nodes_before;
} }