diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 4b10b5c6d..5a6cf22ce 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1771,7 +1771,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1804,7 +1804,7 @@ int main(int argc, char ** argv) {
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index de244680d..f049a3923 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1138,7 +1138,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1171,7 +1171,7 @@ int main(int argc, char ** argv) {
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_train_graphs(
         &model, alloc, ctx_compute,
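
For context, a minimal sketch of the allocation pattern both files converge on after this change, assuming only the public ggml C API (ggml_new_graph_custom); the helper name and the local LLAMA_TRAIN_MAX_NODES fallback are illustrative stand-ins, not code taken from either example:

// Sketch only: build_backward_graphs is a hypothetical helper; the real
// LLAMA_TRAIN_MAX_NODES value comes from llama.h in the actual examples.
#include "ggml.h"

#ifndef LLAMA_TRAIN_MAX_NODES
#define LLAMA_TRAIN_MAX_NODES 16384 // assumed fallback for the sketch
#endif

// Allocate the backward graph (gb) and, when gradient checkpointing is
// enabled, the temporary graph (gb_tmp). After this patch both graphs are
// created with grads == true; previously gb_tmp was created with false.
static struct ggml_cgraph * build_backward_graphs(
        struct ggml_context * ctx_compute,
        bool                  use_checkpointing,
        struct ggml_cgraph ** out_gb_tmp) {
    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    *out_gb_tmp = use_checkpointing
        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true) // was false before this diff
        : NULL;
    return gb;
}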