train : allocate grads for gb_tmp

Georgi Gerganov 2023-11-07 16:42:19 +02:00
parent a4de8042ee
commit 548ec463c6
2 changed files with 4 additions and 4 deletions
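
All four hunks make the same one-line change: the grads argument of
ggml_new_graph_custom is flipped from false to true when creating gb_tmp,
the temporary backward graph used on the gradient-checkpointing path, so
the graph is allocated with a gradient slot per node for the backward-pass
builders to fill in. A minimal sketch of the call, assuming the ggml API
as of this commit:

    // struct ggml_cgraph * ggml_new_graph_custom(
    //         struct ggml_context * ctx, size_t size, bool grads);
    //
    // size  : maximum number of nodes the graph may hold
    // grads : when true, also allocate per-node gradient tensors
    gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, /*grads =*/ true)
             : NULL;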

examples/finetune/finetune.cpp

@@ -1771,7 +1771,7 @@ int main(int argc, char ** argv) {
             gf->order = (enum ggml_cgraph_eval_order) order;
             gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
             gb_tmp = params.common.use_checkpointing
-                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
                      : NULL;
             loss = llama_build_lora_finetune_graphs(
                 &model, &lora, alloc, ctx_compute,
@@ -1804,7 +1804,7 @@ int main(int argc, char ** argv) {
             gf->order = best_order;
             gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
             gb_tmp = params.common.use_checkpointing
-                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
                      : NULL;
             loss = llama_build_lora_finetune_graphs(
                 &model, &lora, alloc, ctx_compute,

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -1138,7 +1138,7 @@ int main(int argc, char ** argv) {
             gf->order = (enum ggml_cgraph_eval_order) order;
             gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
             gb_tmp = params.common.use_checkpointing
-                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
                      : NULL;
             loss = llama_build_train_graphs(
                 &model, alloc, ctx_compute,
@@ -1171,7 +1171,7 @@ int main(int argc, char ** argv) {
             gf->order = best_order;
             gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
             gb_tmp = params.common.use_checkpointing
-                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+                     ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
                      : NULL;
             loss = llama_build_train_graphs(
                 &model, alloc, ctx_compute,
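
Context, assumed from the surrounding code rather than shown in this diff:
on the use_checkpointing path both examples hand gb_tmp to the
gradient-checkpointing backward builder, which writes gradients into the
graph and therefore needs the grad slots allocated above. A sketch of that
downstream call, with the signature as declared in ggml.h of this era:

    // ctx_compute/gf/gb/gb_tmp as in the hunks above; checkpoints and
    // n_checkpoints name the tensors kept alive between recomputations
    ggml_build_backward_gradient_checkpointing(
            ctx_compute, gf, gb, gb_tmp,
            checkpoints, (int) n_checkpoints);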