save trained model to checkpoint and load model to be trained from checkpoint

xaedes 2023-05-17 13:49:32 +02:00
parent d328472f16
commit b241b9cb6c


@@ -204,6 +204,9 @@ struct my_llama_model {
     struct ggml_tensor * output;

     std::vector<my_llama_layer> layers;
+
+    uint32_t train_its = 0;
+    uint32_t train_samples = 0;
 };

 uint32_t get_n_ff(const struct my_llama_hparams* hparams) {
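
Note: the two new counters let a checkpoint remember how much training it has already absorbed. They are serialized at the head of the checkpoint file by save_model/load_model below and advanced once per optimization step in the training loop.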
@@ -1124,11 +1127,12 @@ struct llama_file {
     llama_file(const char * fname, const char * mode) {
         fp = std::fopen(fname, mode);
         if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
         }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
     }

     size_t tell() const {
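
The constructor now treats a missing file as a probe result instead of a fatal error, so callers can test file.fp to decide whether a checkpoint exists. A minimal caller sketch (hypothetical, not part of the commit):

// Hypothetical caller: probe for an existing checkpoint without throwing.
// When the file is absent, fp stays NULL and size is 0.
struct llama_file file("checkpoint.bin", "rb");
if (file.fp == NULL) {
    // no checkpoint yet; fall back to fresh initialization
}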
@@ -1355,18 +1359,135 @@ void set_logits_masked(struct ggml_tensor * logits, std::vector<bool>& mask, flo
     }
 }

+enum llama_file_version {
+    LLAMA_FILE_VERSION_GGML,
+    LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab
+    LLAMA_FILE_VERSION_GGJT_V1, // added padding
+    LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
+};
+
+void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
+    const char * name = ggml_get_name(tensor);
+    uint32_t name_len = strlen(name);
+    uint32_t nd = tensor->n_dims;
+    uint32_t ne[4] = { tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3] };
+    file->write_u32(nd);
+    file->write_u32(name_len);
+    file->write_u32(tensor->type);
+    file->write_raw(ne, sizeof(ne[0]) * nd);
+    file->write_raw(name, name_len);
+    file->seek(-file->tell() & 31, SEEK_CUR);
+    file->write_raw(tensor->data, ggml_nbytes(tensor));
+}
+
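The seek(-file->tell() & 31, SEEK_CUR) line pads the file position to the next 32-byte boundary before the raw tensor data, keeping the data aligned for bulk reads. The arithmetic relies on two's complement: for any offset x, (-x) & 31 equals (32 - x % 32) % 32. A standalone sketch of just that identity:

#include <cstdio>
// Prints the pad needed to reach the next 32-byte boundary for a few offsets.
int main() {
    for (long x : {0L, 1L, 31L, 32L, 33L}) {
        std::printf("offset %ld -> pad %ld\n", x, -x & 31L);
    }
    // offset 0 -> pad 0, 1 -> 31, 31 -> 1, 32 -> 0, 33 -> 31
}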
+void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
+    uint32_t nd = file->read_u32();
+    GGML_ASSERT(nd == tensor->n_dims);
+
+    uint32_t name_len = file->read_u32();
+    enum ggml_type type = (enum ggml_type) file->read_u32();
+    GGML_ASSERT(type == tensor->type);
+
+    uint32_t ne[4];
+    file->read_raw(ne, sizeof(ne[0]) * nd);
+    for (int i=0; i<nd; ++i) {
+        GGML_ASSERT(ne[i] == tensor->ne[i]);
+    }
+
+    std::string name = file->read_string(name_len);
+    file->seek(-file->tell() & 31, SEEK_CUR);
+    GGML_ASSERT(strcmp(ggml_get_name(tensor), name.c_str()) == 0);
+
+    file->read_raw(tensor->data, ggml_nbytes(tensor));
+}
+
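read_tensor is the exact mirror of write_tensor and asserts on every piece of metadata (rank, type, shape, name), so a checkpoint only loads into tensors created identically to those it was written from. A hypothetical round-trip sketch under that assumption (ggml setup and sizes are illustrative only):

// Hypothetical round trip, not part of the commit.
struct ggml_init_params ip = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
struct ggml_context * ctx = ggml_init(ip);
struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
ggml_set_name(t, "test");
{
    struct llama_file fout("tensor.bin", "wb");
    write_tensor(&fout, t);
}   // destructor closes the file
{
    struct llama_file fin("tensor.bin", "rb");
    read_tensor(&fin, t);   // succeeds: all metadata matches
}
ggml_free(ctx);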
+void save_model(struct my_llama_model * model, const char * filename) {
+    struct llama_file file(filename, "wb");
+    if (file.fp == NULL) {
+        return;
+    }
+
+    file.write_u32(model->train_its);
+    file.write_u32(model->train_samples);
+    file.write_u32(model->hparams.n_vocab);
+    file.write_u32(model->hparams.n_embd);
+    file.write_u32(model->hparams.n_mult);
+    file.write_u32(model->hparams.n_head);
+    file.write_u32(model->hparams.n_layer);
+    file.write_u32(model->hparams.n_rot);
+
+    write_tensor(&file, model->tok_embeddings);
+    write_tensor(&file, model->norm);
+    write_tensor(&file, model->output);
+
+    for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        write_tensor(&file, layer.attention_norm);
+        write_tensor(&file, layer.wq);
+        write_tensor(&file, layer.wk);
+        write_tensor(&file, layer.wv);
+        write_tensor(&file, layer.wo);
+        write_tensor(&file, layer.ffn_norm);
+        write_tensor(&file, layer.w1);
+        write_tensor(&file, layer.w2);
+        write_tensor(&file, layer.w3);
+    }
+}
+
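For reference, the checkpoint layout save_model produces (scalars written in host byte order via write_u32/write_raw):

// train_its, train_samples                 -- u32 each
// n_vocab, n_embd, n_mult,
// n_head, n_layer, n_rot                   -- u32 each
// then, for every tensor in the fixed order above:
//   nd, name_len, type                     -- u32 each
//   ne[0..nd-1]                            -- u32 each
//   name bytes (no terminator)
//   zero padding to a 32-byte file offset
//   raw tensor data (ggml_nbytes)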
+void load_model(struct my_llama_model * model, const char * filename, bool init) {
+    struct llama_file file(filename, "rb");
+
+    if (file.fp) {
+        printf("%s: Loading model from '%s'.\n", __func__, filename);
+        model->train_its = file.read_u32();
+        model->train_samples = file.read_u32();
+        model->hparams.n_vocab = file.read_u32();
+        model->hparams.n_embd = file.read_u32();
+        model->hparams.n_mult = file.read_u32();
+        model->hparams.n_head = file.read_u32();
+        model->hparams.n_layer = file.read_u32();
+        model->hparams.n_rot = file.read_u32();
+        printf("%s: Training iterations: %u.\n", __func__, model->train_its);
+        printf("%s: Training samples: %u.\n", __func__, model->train_samples);
+        print_params(&model->hparams);
+    }
+
+    if (init) {
+        init_model(model);
+    }
+
+    if (file.fp) {
+        read_tensor(&file, model->tok_embeddings);
+        read_tensor(&file, model->norm);
+        read_tensor(&file, model->output);
+
+        for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+            auto & layer = model->layers[i];
+
+            read_tensor(&file, layer.attention_norm);
+            read_tensor(&file, layer.wq);
+            read_tensor(&file, layer.wk);
+            read_tensor(&file, layer.wv);
+            read_tensor(&file, layer.wo);
+            read_tensor(&file, layer.ffn_norm);
+            read_tensor(&file, layer.w1);
+            read_tensor(&file, layer.w2);
+            read_tensor(&file, layer.w3);
+        }
+    }
+}
+
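load_model reads the hyperparameters first, so with init == true the freshly allocated tensors already have the shapes the subsequent read_tensor calls assert on; when the file is missing it degrades to a plain init_model. A sketch of both paths (hypothetical caller):

// Hypothetical caller: works whether or not checkpoint.bin exists.
load_model(&model, "checkpoint.bin", /*init=*/true);
// file present: hparams, counters and weights come from the checkpoint
// file missing: init_model(&model) ran with the caller's existing hparams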
 int main(int argc, char ** argv) {
     const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
     const char * default_train = "shakespeare.txt";
-    const char * default_argv[3] = {argv[0], default_model, default_train};
+    const char * default_checkpoint = "checkpoint.bin";
+    const char * default_argv[4] = {argv[0], default_model, default_train, default_checkpoint};

-    if (argc < 3) {
+    if (argc < 4) {
         fprintf(stderr, "usage: %s model training_data\n", argv[0]);
         //return 1;
     }

     const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
     const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
+    const char * fn_chkpt = (argc >= 4) ? argv[3] : default_argv[3];

     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
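With the extra positional argument the example now takes model, training data, and checkpoint path, all optional thanks to the defaults above. A hypothetical invocation (the binary name depends on how the example is built):

./train ggml-vic7b-uncensored-q4_0.bin shakespeare.txt checkpoint.bin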
@@ -1420,7 +1541,7 @@ int main(int argc, char ** argv) {
     my_llama_sampler sampler;

     printf("%s: init model\n", __func__);
-    init_model(&model);
+    load_model(&model, fn_chkpt, true);
     set_param_model(&model);
     randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
     init_kv_cache(&kv_self, &model, n_batch);
@ -1498,8 +1619,16 @@ int main(int argc, char ** argv) {
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = gf.n_threads;
opt_params_lbfgs.lbfgs.n_iter = 16;
ggml_opt(ctx0, opt_params_adam, e);
// ggml_opt(ctx0, opt_params_lbfgs, e);
bool use_adam = true;
if (use_adam) {
ggml_opt(ctx0, opt_params_adam, e);
} else {
ggml_opt(ctx0, opt_params_lbfgs, e);
}
model.train_its += use_adam ? opt_params_adam.adam.n_iter : opt_params_lbfgs.lbfgs.n_iter;
model.train_samples += n_batch;
ggml_build_forward_expand(&gf, e);
ggml_graph_compute(ctx0, &gf);
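The counters advance by the number of optimizer iterations configured for the chosen optimizer (not the iterations ggml_opt actually performed if it converges early) and by one batch of samples per step. For example, with illustrative numbers (assuming adam.n_iter were 16 and n_batch were 8):

// each pass through the block above would then do:
//   model.train_its     += 16;  // configured optimizer iterations
//   model.train_samples += 8;   // examples consumed this step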
@@ -1541,6 +1670,8 @@ int main(int argc, char ** argv) {
         ggml_free(ctx0);
     }

+    save_model(&model, fn_chkpt);
+
     {
         int n_gen = 128;
         int sample_ctx = n_tokens - n_tokens/8;
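
Because fn_chkpt is both the load path at startup and the save path here, re-running the program keeps extending the same checkpoint, with train_its and train_samples accumulating across runs.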