From 16e819d53ce5bb7025a545bd45b6404c16f3d432 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 2 Nov 2023 19:59:35 +0200
Subject: [PATCH] sync : pass custom graph sizes in training examples

---
 common/train.cpp                |  1 +
 examples/finetune/finetune.cpp  | 13 +++++++------
 .../train-text-from-scratch.cpp | 13 +++++++------
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/common/train.cpp b/common/train.cpp
index bc15b7a03..964b156b5 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -32,6 +32,7 @@ struct train_state * init_train_state() {
     state->opt = new struct ggml_opt_context;
     state->opt->ctx = NULL;
     state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     state->opt->loss_after = 0.0f;
 
     return state;
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 248927966..f8669af41 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1768,11 +1769,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1801,11 +1802,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 8e5eb4230..5d1a97b5e 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1006,6 +1006,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1135,11 +1136,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1168,11 +1169,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
        : NULL;
     loss = llama_build_train_graphs(
         &model, alloc, ctx_compute,
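
Background on the pattern above: ggml_new_graph() allocates a graph sized for
the default node capacity (GGML_DEFAULT_GRAPH_SIZE), which the large training
graphs built by these examples can exceed; ggml_new_graph_custom() instead
takes an explicit node capacity plus a flag controlling whether gradient
bookkeeping is allocated, and ggml_opt_params.graph_size sizes the optimizer's
internal graphs to match. A minimal, self-contained sketch of the call
pattern, assuming a ggml revision from around this commit; EXAMPLE_MAX_NODES
and the context memory size are illustrative stand-ins, not values taken from
the patch:

    #include "ggml.h"

    // Illustrative stand-in for LLAMA_TRAIN_MAX_NODES (defined in common/train.h).
    #define EXAMPLE_MAX_NODES 16384

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 128*1024*1024, // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // Forward graph: grads = true allocates the graph's gradient
        // bookkeeping so it can participate in backward-pass construction.
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, EXAMPLE_MAX_NODES, true);

        // Backward graph: grads = false, it only records the backward ops.
        struct ggml_cgraph * gb = ggml_new_graph_custom(ctx, EXAMPLE_MAX_NODES, false);

        (void) gf;
        (void) gb;
        ggml_free(ctx);
        return 0;
    }

This mirrors the hunks above, where only the forward graph gf is created with
grads enabled while gb and gb_tmp are not.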