sync : pass custom graph sizes in training examples

Author: Georgi Gerganov
Date:   2023-11-02 19:59:35 +02:00
Commit: 16e819d53c (parent: 815f44e5a3)
3 changed files with 15 additions and 12 deletions
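
For context, a minimal sketch of the pattern this commit applies in all three training examples, assuming the ggml API at this revision: compute graphs are created with ggml_new_graph_custom() at LLAMA_TRAIN_MAX_NODES capacity instead of the default-sized ggml_new_graph(), and the optimizer is given the same capacity through the graph_size field of ggml_opt_params. The helper function and its bare-bones setup below are illustrative, not code from the commit.

#include "ggml.h"
#include "train.h"   // LLAMA_TRAIN_MAX_NODES (common/train.h in llama.cpp)

// Illustrative helper: allocate the training graphs with an explicit node
// capacity instead of the default, and pass the same capacity on to the
// optimizer so the work graphs it builds internally are sized to match.
static void setup_train_graphs(struct ggml_context * ctx_compute, struct ggml_opt_context * opt) {
    // forward graph, with gradient bookkeeping (grads = true), as in the diff
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    // backward graph, without its own gradient storage (grads = false)
    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);

    opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;   // the field this sync passes through

    (void) gf; (void) gb;   // graph construction and optimization happen elsewhere
}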

common/train.cpp

@@ -32,6 +32,7 @@ struct train_state * init_train_state() {
     state->opt = new struct ggml_opt_context;
     state->opt->ctx = NULL;
     state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     state->opt->loss_after = 0.0f;
     return state;

examples/finetune/finetune.cpp

@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1768,11 +1769,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1801,11 +1802,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -1006,6 +1006,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1135,11 +1136,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1168,11 +1169,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
         : NULL;
     loss = llama_build_train_graphs(
         &model, alloc, ctx_compute,