From 548ec463c68e1586c25d8fcdb1db246b578ca903 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 7 Nov 2023 16:42:19 +0200
Subject: [PATCH] train : allocate grads for gb_tmp

---
 examples/finetune/finetune.cpp                                | 4 ++--
 examples/train-text-from-scratch/train-text-from-scratch.cpp | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 4b10b5c6d..5a6cf22ce 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1771,7 +1771,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1804,7 +1804,7 @@ int main(int argc, char ** argv) {
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index de244680d..f049a3923 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1138,7 +1138,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1171,7 +1171,7 @@ int main(int argc, char ** argv) {
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_train_graphs(
         &model, alloc, ctx_compute,
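
Editor's note (not part of the patch): a minimal sketch of the call pattern this change affects, assuming the ggml API as of this commit. The helper alloc_train_graphs and the max_nodes value are hypothetical stand-ins for illustration only. With checkpointing enabled, the temporary backward graph gb_tmp is now created with the grads flag set, matching gb, so gradient tensors are allocated for it as well.

    #include "ggml.h"
    #include <cstddef>

    // Sketch: allocate the backward graphs used during training.
    // The third argument of ggml_new_graph_custom controls whether
    // gradient tensors are allocated for the graph's nodes.
    static void alloc_train_graphs(struct ggml_context * ctx_compute, bool use_checkpointing,
                                   struct ggml_cgraph ** gb, struct ggml_cgraph ** gb_tmp) {
        const size_t max_nodes = 1 << 16; // stand-in for LLAMA_TRAIN_MAX_NODES (hypothetical value)

        *gb     = ggml_new_graph_custom(ctx_compute, max_nodes, /*grads =*/ true);
        *gb_tmp = use_checkpointing
            ? ggml_new_graph_custom(ctx_compute, max_nodes, /*grads =*/ true) // was false before this patch
            : NULL;
    }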