From db38d2bce4223fdfcb355f18c4edfd37a1e500af Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 17 Sep 2023 17:33:11 +0200
Subject: [PATCH] train-text-from-scratch: automatically allocate opt context

---
 examples/finetune/finetune.cpp                |   6 +-
 .../train-text-from-scratch.cpp               | 101 ++++++++++++------
 2 files changed, 68 insertions(+), 39 deletions(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index c43d00dfd..caeea9c3f 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1651,6 +1651,7 @@ int main(int argc, char ** argv) {
             ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
         }
     }
+    opt->iter = train->train_its;
 
     print_params(&model.hparams);
     print_lora_params(&lora.hparams);
@@ -1660,7 +1661,7 @@ int main(int argc, char ** argv) {
     printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
     printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
-    opt->iter = train->train_its;
+    printf("%s: opt iter %d\n", __func__, opt->iter);
 
     if (params.only_write_lora) {
         save_train_files_data save_data;
@@ -1684,9 +1685,6 @@ int main(int argc, char ** argv) {
     int n_vocab = model.hparams.n_vocab;
     int n_batch = params.common.n_batch;
 
-    printf("%s: opt iter %d\n", __func__, opt->iter);
-
-    printf("used_mem model: %zu bytes\n", ggml_used_mem(lora.ctx) + lora.data.size());
 
     std::vector<uint8_t> mem_input_data;
     std::vector<uint8_t> mem_compute_data;
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 83e156363..069e460c1 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -975,6 +975,27 @@ static void save_train_files(void * vdata, struct train_state * train) {
     }
 }
 
+static int64_t get_parameter_count(struct my_llama_model* model) {
+    int64_t nx = 0;
+    nx += ggml_nelements(model->tok_embeddings);
+    nx += ggml_nelements(model->norm);
+    nx += ggml_nelements(model->output);
+
+    for (uint32_t i = 0; i < model->layers.size(); ++i) {
+        auto & layer = model->layers[i];
+        nx += ggml_nelements(layer.attention_norm);
+        nx += ggml_nelements(layer.wq);
+        nx += ggml_nelements(layer.wk);
+        nx += ggml_nelements(layer.wv);
+        nx += ggml_nelements(layer.wo);
+        nx += ggml_nelements(layer.ffn_norm);
+        nx += ggml_nelements(layer.w1);
+        nx += ggml_nelements(layer.w2);
+        nx += ggml_nelements(layer.w3);
+    }
+    return nx;
+}
+
 int main(int argc, char ** argv) {
     struct train_params params = get_default_train_params();
 
@@ -1007,52 +1028,58 @@ int main(int argc, char ** argv) {
     model.hparams.rope_freq_base  = params.rope_freq_base;
     model.hparams.rope_freq_scale = params.rope_freq_scale;
 
-    print_params(&model.hparams);
-
-    int n_tokens = model.hparams.n_ctx;
-    int n_vocab  = model.hparams.n_vocab;
-    int n_batch  = params.common.n_batch;
-
     struct train_state * train = init_train_state();
     struct ggml_opt_context * opt = train->opt;
 
-    struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
-    opt_params_adam.print_forward_graph     = false;
-    opt_params_adam.print_backward_graph    = false;
-    opt_params_adam.n_threads               = params.common.n_threads;
-    opt_params_adam.past                    = params.common.opt_past;
-    opt_params_adam.delta                   = params.common.opt_delta;
-    opt_params_adam.max_no_improvement      = params.common.opt_max_no_improvement;
-    opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
-    opt_params_adam.adam.n_iter             = params.common.adam_n_iter;
-    opt_params_adam.adam.sched              = 1.0f;
-    opt_params_adam.adam.alpha              = params.common.adam_alpha;
-    opt_params_adam.adam.decay              = params.common.adam_decay;
-    opt_params_adam.adam.decay_min_ndim     = params.common.adam_decay_min_ndim;
-    opt_params_adam.adam.beta1              = params.common.adam_beta1;
-    opt_params_adam.adam.beta2              = params.common.adam_beta2;
-    opt_params_adam.adam.gclip              = params.common.adam_gclip;
-    opt_params_adam.adam.eps_f              = params.common.adam_eps_f;
-
-    opt->params = opt_params_adam;
+    // set opt params from command line
+    opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    opt->params.print_forward_graph     = false;
+    opt->params.print_backward_graph    = false;
+    opt->params.n_threads               = params.common.n_threads;
+    opt->params.past                    = params.common.opt_past;
+    opt->params.delta                   = params.common.opt_delta;
+    opt->params.max_no_improvement      = params.common.opt_max_no_improvement;
+    opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt->params.adam.n_iter             = params.common.adam_n_iter;
+    opt->params.adam.sched              = 1.0f;
+    opt->params.adam.alpha              = params.common.adam_alpha;
+    opt->params.adam.decay              = params.common.adam_decay;
+    opt->params.adam.decay_min_ndim     = params.common.adam_decay_min_ndim;
+    opt->params.adam.beta1              = params.common.adam_beta1;
+    opt->params.adam.beta2              = params.common.adam_beta2;
+    opt->params.adam.gclip              = params.common.adam_gclip;
+    opt->params.adam.eps_f              = params.common.adam_eps_f;
 
     printf("%s: init model\n", __func__);
     bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
-    if (!existed) {
+    if (existed) {
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.common.custom_n_ctx) {
+            model.hparams.n_ctx = params.common.n_ctx;
+        }
+
+        const bool opt_past_changed = opt->params.past != params.common.opt_past;
+
+        if (opt_past_changed) {
+            die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use a different value, train from scratch with an empty input checkpoint, e.g. --checkpoint-in ''. Aborting");
Aborting"); + // need to discard previous optimizer past function value statistics and opt_init with new shapes + // TODO + } + } else { init_model(&model); - } - - opt->params = opt_params_adam; - - opt->iter = train->train_its; - printf("%s: opt iter %d\n", __func__, opt->iter); - - bool from_scratch = !existed; - if (from_scratch) { randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f); + ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&model)); } + opt->iter = train->train_its; + print_params(&model.hparams); + printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its); + printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples); + printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens); + printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs); printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f)); + printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f)); + printf("%s: opt iter %d\n", __func__, opt->iter); // TODO: use std::vector intead of "new" size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb); @@ -1066,6 +1093,10 @@ int main(int argc, char ** argv) { alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); } + int n_tokens = model.hparams.n_ctx; + int n_vocab = model.hparams.n_vocab; + int n_batch = params.common.n_batch; + std::vector train_tokens; std::vector train_samples_begin; std::vector train_samples_size;