store optimizer state in training checkpoint and add learning schedule

persistent optimizer state allows to resume training without resetting the optimizer
learning schedule consists of linear warmup ramp followed by cosine decay with restarts
This commit is contained in:
xaedes 2023-05-21 21:36:04 +02:00
parent 37c69435f0
commit 42d9b4cfc2
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -208,6 +208,7 @@ struct my_llama_model {
uint32_t train_its = 0;
uint32_t train_samples = 0;
uint32_t train_tokens = 0;
};
uint32_t get_n_ff(const struct my_llama_hparams* hparams) {
@ -237,6 +238,10 @@ void init_model(struct my_llama_model * model) {
struct ggml_context * ctx = model->ctx;
model->train_its = 0;
model->train_samples = 0;
model->train_tokens = 0;
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
@ -1613,6 +1618,13 @@ enum llama_file_version {
};
void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
if (tensor == NULL) {
file->write_u32(0);
file->write_u32(0);
file->write_u32(GGML_TYPE_F32);
file->seek(-file->tell() & 31, SEEK_CUR);
return;
}
const char * name = ggml_get_name(tensor);
uint32_t name_len = strlen(name);
uint32_t nd = tensor->n_dims;
@ -1629,28 +1641,135 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
uint32_t nd = file->read_u32();
GGML_ASSERT(nd == tensor->n_dims);
uint32_t name_len = file->read_u32();
enum ggml_type type = (enum ggml_type) file->read_u32();
uint32_t name_len = file->read_u32();
enum ggml_type type = (enum ggml_type) file->read_u32();
GGML_ASSERT(type == tensor->type);
uint32_t ne[4];
file->read_raw(ne, sizeof(ne[0]) * nd);
for (int i=0; i<nd; ++i) {
GGML_ASSERT(ne[i] == tensor->ne[i]);
}
std::string name = file->read_string(name_len);
file->seek(-file->tell() & 31, SEEK_CUR);
std::string name = file->read_string(name_len);
GGML_ASSERT(strcmp(ggml_get_name(tensor), name.c_str()) == 0);
file->seek(-file->tell() & 31, SEEK_CUR);
file->read_raw(tensor->data, ggml_nbytes(tensor));
}
void save_model(struct my_llama_model * model, const char * filename) {
void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) {
const uint32_t version = 0;
GGML_ASSERT(opt->nx >= 0);
GGML_ASSERT(opt->iter >= 0);
file->write_u32(version);
file->write_raw(&opt->params, sizeof(opt->params));
file->write_raw(&opt->nx, sizeof(opt->nx));
file->write_raw(&opt->iter, sizeof(opt->iter));
file->write_u32((uint32_t) opt->just_initialized);
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
GGML_ASSERT(opt->adam.x != NULL);
write_tensor(file, opt->adam.x);
write_tensor(file, opt->adam.g1);
write_tensor(file, opt->adam.g2);
write_tensor(file, opt->adam.m);
write_tensor(file, opt->adam.v);
write_tensor(file, opt->adam.mh);
write_tensor(file, opt->adam.vh);
write_tensor(file, opt->adam.pf);
file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
} break;
case GGML_OPT_LBFGS:
{
GGML_ASSERT(opt->adam.x != NULL);
write_tensor(file, opt->lbfgs.x);
write_tensor(file, opt->lbfgs.xp);
write_tensor(file, opt->lbfgs.g);
write_tensor(file, opt->lbfgs.gp);
write_tensor(file, opt->lbfgs.d);
write_tensor(file, opt->lbfgs.pf);
write_tensor(file, opt->lbfgs.lmal);
write_tensor(file, opt->lbfgs.lmys);
write_tensor(file, opt->lbfgs.lms);
write_tensor(file, opt->lbfgs.lmy);
file->write_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best));
file->write_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step));
file->write_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j));
file->write_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k));
file->write_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end));
file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
} break;
}
}
void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
uint32_t version = file->read_u32();
GGML_ASSERT(version == 0);
file->read_raw(&opt->params, sizeof(opt->params));
file->read_raw(&opt->nx, sizeof(opt->nx));
ggml_opt_init(ctx, opt, opt->params, opt->nx);
file->read_raw(&opt->iter, sizeof(opt->iter));
opt->just_initialized = (bool) file->read_u32();
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
read_tensor(file, opt->adam.x);
read_tensor(file, opt->adam.g1);
read_tensor(file, opt->adam.g2);
read_tensor(file, opt->adam.m);
read_tensor(file, opt->adam.v);
read_tensor(file, opt->adam.mh);
read_tensor(file, opt->adam.vh);
if (opt->adam.pf) { read_tensor(file, opt->adam.pf); }
file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
} break;
case GGML_OPT_LBFGS:
{
GGML_ASSERT(opt->adam.x != NULL);
read_tensor(file, opt->lbfgs.x);
read_tensor(file, opt->lbfgs.xp);
read_tensor(file, opt->lbfgs.g);
read_tensor(file, opt->lbfgs.gp);
read_tensor(file, opt->lbfgs.d);
if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); }
read_tensor(file, opt->lbfgs.lmal);
read_tensor(file, opt->lbfgs.lmys);
read_tensor(file, opt->lbfgs.lms);
read_tensor(file, opt->lbfgs.lmy);
file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best));
file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step));
file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j));
file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k));
file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end));
file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
} break;
}
}
void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) {
struct llama_file file(filename, "wb");
if (file.fp == NULL) {
return;
}
const uint32_t magic = 'ggcp';
const uint32_t version = 0;
file.write_u32(magic);
file.write_u32(version);
file.write_u32(model->train_its);
file.write_u32(model->train_samples);
file.write_u32(model->train_tokens);
file.write_u32(model->hparams.n_vocab);
file.write_u32(model->hparams.n_embd);
file.write_u32(model->hparams.n_mult);
@ -1675,23 +1794,35 @@ void save_model(struct my_llama_model * model, const char * filename) {
write_tensor(&file, layer.w2);
write_tensor(&file, layer.w3);
}
write_opt_context(&file, opt);
}
bool load_model(struct my_llama_model * model, const char * filename, bool init) {
bool load_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename, bool init) {
struct llama_file file(filename, "rb");
uint32_t magic;
uint32_t version;
uint32_t train_its = 0;
uint32_t train_samples = 0;
uint32_t train_tokens = 0;
if (file.fp) {
printf("%s: Loading model from '%s'.\n", __func__, filename);
model->train_its = file.read_u32();
model->train_samples = file.read_u32();
magic = file.read_u32();
GGML_ASSERT(magic == 'ggcp');
version = file.read_u32();
GGML_ASSERT(version == 0);
train_its = file.read_u32();
train_samples = file.read_u32();
train_tokens = file.read_u32();
model->hparams.n_vocab = file.read_u32();
model->hparams.n_embd = file.read_u32();
model->hparams.n_mult = file.read_u32();
model->hparams.n_head = file.read_u32();
model->hparams.n_layer = file.read_u32();
model->hparams.n_rot = file.read_u32();
printf("%s: Training iterations: %u.\n", __func__, model->train_its);
printf("%s: Training samples: %u.\n", __func__, model->train_samples);
print_params(&model->hparams);
}
@ -1699,6 +1830,16 @@ bool load_model(struct my_llama_model * model, const char * filename, bool init)
init_model(model);
}
if (file.fp) {
model->train_its = train_its;
model->train_samples = train_samples;
model->train_tokens = train_tokens;
}
printf("%s: Training iterations: %u.\n", __func__, model->train_its);
printf("%s: Training samples: %u.\n", __func__, model->train_samples);
printf("%s: Training tokens: %u.\n", __func__, model->train_tokens);
if (file.fp) {
read_tensor(&file, model->tok_embeddings);
read_tensor(&file, model->norm);
@ -1717,11 +1858,30 @@ bool load_model(struct my_llama_model * model, const char * filename, bool init)
read_tensor(&file, layer.w2);
read_tensor(&file, layer.w3);
}
read_opt_context(&file, model->ctx, opt);
}
return (file.fp != NULL);
}
float cosine_decay(const int decay_steps, const float alpha, int step) {
if (step > decay_steps) {
step = decay_steps;
}
const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
const float decay = (1 - alpha)*cosine_decay + alpha;
return decay;
}
float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) {
while (step > decay_steps) {
step -= decay_steps;
decay_steps = (int) restart_step_mult * decay_steps;
}
return cosine_decay(decay_steps, alpha, step);
}
int main(int argc, char ** argv) {
const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
const char * default_train = "shakespeare.txt";
@ -1795,16 +1955,55 @@ int main(int argc, char ** argv) {
my_llama_sampler sampler;
int n_threads = 6;
bool use_adam = true;
int warmup = 100;
int cos_decay_steps = 1000;
float cos_decay_restart = 1.1f;
float cos_decay_alpha = 0.0f;
struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
memset(opt, 0, sizeof(struct ggml_opt_context));
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = n_threads;
opt_params_adam.adam.n_iter = 16;
opt_params_adam.adam.sched = 1.0f;
opt_params_adam.adam.alpha = 1e-3;
opt_params_adam.adam.decay = 1e-3;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = n_threads;
opt_params_lbfgs.lbfgs.n_iter = 16;
opt->ctx = model.ctx;
opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
printf("%s: init model\n", __func__);
bool existed = load_model(&model, fn_chkpt_in, true);
bool from_scratch = !existed;
bool existed = load_checkpoint(&model, opt, fn_chkpt_in, true);
set_param_model(&model);
opt->iter = model.train_its;
printf("%s: opt iter %d\n", __func__, opt->iter);
bool from_scratch = !existed;
if (from_scratch) {
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
}
init_kv_cache(&kv_self, &model, n_batch);
init_kv_cache(&kv_self, &model, 1);
// init_kv_cache(&kv_self, &model, n_batch);
init_sampler(&sampler, lctx);
printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
// ggml_print_tensor_objects(model.ctx);
size_t compute_size = 1024ll*1024ll*1024ll*32ll;
uint8_t * compute_addr = new uint8_t[compute_size];
@ -1853,7 +2052,7 @@ int main(int argc, char ** argv) {
int n_past = 0;
ggml_cgraph gf = {};
gf.n_threads = 6;
gf.n_threads = n_threads;
get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
@ -1875,30 +2074,20 @@ int main(int argc, char ** argv) {
float error_before_opt = ggml_get_f32_1d(e, 0);
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = gf.n_threads;
opt_params_adam.adam.n_iter = 16;
opt_params_adam.adam.alpha = 1e-4;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = gf.n_threads;
opt_params_lbfgs.lbfgs.n_iter = 16;
opt->params.adam.sched = (opt->iter < warmup)
? (float) opt->iter / (float) warmup
: cosine_decay_restart(cos_decay_steps, cos_decay_alpha, opt->iter - warmup, cos_decay_restart);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
bool use_adam = true;
if (use_adam) {
ggml_opt(ctx0, opt_params_adam, e);
} else {
ggml_opt(ctx0, opt_params_lbfgs, e);
}
// ggml_opt(ctx0, opt->params, e);
ggml_opt_resume(ctx0, opt, e);
size_t used_mem_after_opt = ggml_used_mem(ctx0);
model.train_its += use_adam ? opt_params_adam.adam.n_iter : opt_params_lbfgs.lbfgs.n_iter;
model.train_its = opt->iter;
// model.train_its += use_adam ? opt_params_adam.adam.n_iter : opt_params_lbfgs.lbfgs.n_iter;
model.train_samples += n_batch;
model.train_tokens += n_batch * n_tokens;
ggml_build_forward_expand(&gf, e);
ggml_graph_compute(ctx0, &gf);
@ -1909,7 +2098,7 @@ int main(int argc, char ** argv) {
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
if (ex % 1 == 0) {
printf("Example %d\n", ex);
printf("Example %d, opt iter %d\n", ex, opt->iter);
printf("error_before_opt: %.6f\n", error_before_opt);
printf("error_after_opt: %.6f\n", error_after_opt);
}
@ -1943,7 +2132,7 @@ int main(int argc, char ** argv) {
ggml_free(ctx0);
}
save_model(&model, fn_chkpt_out);
save_checkpoint(&model, opt, fn_chkpt_out);
{
int n_gen = 1024;