train-text-from-scratch: automatically allocate opt context

xaedes 2023-09-17 17:33:11 +02:00
parent 9e10fa977e
commit db38d2bce4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 68 additions and 39 deletions
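
In short: the optimizer context is now sized from the model's actual number of trainable elements instead of a pre-reserved buffer. A minimal sketch of the pattern, using the same calls that appear in the diffs below (the surrounding setup is assumed, not quoted from the commit):

    // nx = total trainable elements (see get_parameter_count in the second file);
    // ggml_opt_init then allocates the Adam optimizer state for exactly nx
    // parameters inside opt->ctx, so the caller no longer guesses a buffer size.
    int64_t nx = get_parameter_count(&model);
    ggml_opt_init(opt->ctx, opt, opt->params, nx);
    printf("opt_size = %zu bytes\n", ggml_get_mem_size(opt->ctx));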

View file

@@ -1651,6 +1651,7 @@ int main(int argc, char ** argv) {
             ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
         }
     }
+    opt->iter = train->train_its;
     print_params(&model.hparams);
     print_lora_params(&lora.hparams);
@@ -1660,7 +1661,7 @@ int main(int argc, char ** argv) {
     printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
     printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
-    opt->iter = train->train_its;
     printf("%s: opt iter %d\n", __func__, opt->iter);
     if (params.only_write_lora) {
         save_train_files_data save_data;
@@ -1684,9 +1685,6 @@ int main(int argc, char ** argv) {
     int n_vocab = model.hparams.n_vocab;
     int n_batch = params.common.n_batch;
-    printf("%s: opt iter %d\n", __func__, opt->iter);
     printf("used_mem model: %zu bytes\n", ggml_used_mem(lora.ctx) + lora.data.size());
     std::vector<uint8_t> mem_input_data;
     std::vector<uint8_t> mem_compute_data;

View file

@@ -975,6 +975,27 @@ static void save_train_files(void * vdata, struct train_state * train) {
     }
 }
 
+static int64_t get_parameter_count(struct my_llama_model* model) {
+    int64_t nx = 0;
+    nx += ggml_nelements(model->tok_embeddings);
+    nx += ggml_nelements(model->norm);
+    nx += ggml_nelements(model->output);
+    for (uint32_t i = 0; i < model->layers.size(); ++i) {
+        auto & layer = model->layers[i];
+        nx += ggml_nelements(layer.attention_norm);
+        nx += ggml_nelements(layer.wq);
+        nx += ggml_nelements(layer.wk);
+        nx += ggml_nelements(layer.wv);
+        nx += ggml_nelements(layer.wo);
+        nx += ggml_nelements(layer.ffn_norm);
+        nx += ggml_nelements(layer.w1);
+        nx += ggml_nelements(layer.w2);
+        nx += ggml_nelements(layer.w3);
+    }
+    return nx;
+}
+
 int main(int argc, char ** argv) {
     struct train_params params = get_default_train_params();
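
Note that ggml_nelements returns element counts, not bytes, so get_parameter_count yields a parameter count that ggml_opt_init can turn into optimizer-state tensors. As a rough orientation (my estimate, not something the commit computes): Adam keeps at least a first and a second moment per parameter in f32, so a hypothetical lower bound on the state this count implies is:

    #include <stddef.h>   // size_t
    #include <stdint.h>   // int64_t

    // Hypothetical helper, not part of the commit: lower bound on Adam state,
    // assuming two f32 moment tensors (m and v) per parameter; gradient
    // accumulation and past-value buffers come on top of this.
    static size_t adam_state_lower_bound(int64_t nx) {
        return (size_t) nx * 2 * sizeof(float);
    }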
@@ -1007,52 +1028,58 @@ int main(int argc, char ** argv) {
     model.hparams.rope_freq_base  = params.rope_freq_base;
     model.hparams.rope_freq_scale = params.rope_freq_scale;
     print_params(&model.hparams);
-    int n_tokens = model.hparams.n_ctx;
-    int n_vocab = model.hparams.n_vocab;
-    int n_batch = params.common.n_batch;
     struct train_state * train = init_train_state();
     struct ggml_opt_context * opt = train->opt;
-    struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
-    opt_params_adam.print_forward_graph = false;
-    opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = params.common.n_threads;
-    opt_params_adam.past = params.common.opt_past;
-    opt_params_adam.delta = params.common.opt_delta;
-    opt_params_adam.max_no_improvement = params.common.opt_max_no_improvement;
-    opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
-    opt_params_adam.adam.n_iter = params.common.adam_n_iter;
-    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = params.common.adam_alpha;
-    opt_params_adam.adam.decay = params.common.adam_decay;
-    opt_params_adam.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
-    opt_params_adam.adam.beta1 = params.common.adam_beta1;
-    opt_params_adam.adam.beta2 = params.common.adam_beta2;
-    opt_params_adam.adam.gclip = params.common.adam_gclip;
-    opt_params_adam.adam.eps_f = params.common.adam_eps_f;
-    opt->params = opt_params_adam;
+    // set opt params from command line
+    opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    opt->params.print_forward_graph = false;
+    opt->params.print_backward_graph = false;
+    opt->params.n_threads = params.common.n_threads;
+    opt->params.past = params.common.opt_past;
+    opt->params.delta = params.common.opt_delta;
+    opt->params.max_no_improvement = params.common.opt_max_no_improvement;
+    opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt->params.adam.n_iter = params.common.adam_n_iter;
+    opt->params.adam.sched = 1.0f;
+    opt->params.adam.alpha = params.common.adam_alpha;
+    opt->params.adam.decay = params.common.adam_decay;
+    opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+    opt->params.adam.beta1 = params.common.adam_beta1;
+    opt->params.adam.beta2 = params.common.adam_beta2;
+    opt->params.adam.gclip = params.common.adam_gclip;
+    opt->params.adam.eps_f = params.common.adam_eps_f;
     printf("%s: init model\n", __func__);
     bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
-    if (!existed) {
+    if (existed) {
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.common.custom_n_ctx) {
+            model.hparams.n_ctx = params.common.n_ctx;
+        }
+        const bool opt_past_changed = opt->params.past != params.common.opt_past;
+        if (opt_past_changed) {
+            die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value train from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
+            // need to discard previous optimizer past function value statistics and opt_init with new shapes
+            // TODO
+        }
+    } else {
         init_model(&model);
     }
-    opt->params = opt_params_adam;
-    opt->iter = train->train_its;
-    printf("%s: opt iter %d\n", __func__, opt->iter);
     bool from_scratch = !existed;
     if (from_scratch) {
         randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+        ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&model));
     }
+    opt->iter = train->train_its;
     print_params(&model.hparams);
     printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
     printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
     printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
     printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
     printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f));
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
+    printf("%s: opt iter %d\n", __func__, opt->iter);
 
     // TODO: use std::vector<uint8_t> intead of "new"
     size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
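
The new `--opt-past` guard exists because the optimizer state shapes depend on params.past: the history buffer used for the delta/max_no_improvement convergence test is sized by it, so state saved in a checkpoint with one value cannot be restored into a context initialized with another. A sketch of the dependency, paraphrasing ggml's optimizer init of this era (an assumption, not a verbatim excerpt):

    // Inside ggml_opt_init (paraphrased): the past-function-values tensor is
    // only allocated, and sized, when params.past > 0.
    struct ggml_tensor * pf = params.past > 0
        ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
        : NULL;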
@@ -1066,6 +1093,10 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
     }
 
+    int n_tokens = model.hparams.n_ctx;
+    int n_vocab = model.hparams.n_vocab;
+    int n_batch = params.common.n_batch;
+
     std::vector<llama_token> train_tokens;
     std::vector<size_t> train_samples_begin;
     std::vector<size_t> train_samples_size;
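
Moving n_tokens, n_vocab, and n_batch below checkpoint loading is deliberate: model.hparams.n_ctx is only final after the checkpoint has been read and the optional user override applied, so reading it at the old location could bake in a stale value. The order, condensed from the hunks above:

    // Condensed from this commit's hunks: hparams must be settled first.
    bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
    if (existed && params.common.custom_n_ctx) {
        model.hparams.n_ctx = params.common.n_ctx;  // user-provided n_ctx wins
    }
    int n_tokens = model.hparams.n_ctx;             // now safe to read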