move train data saving code into callback to unify code of opt_callback

train_params still differs between finetune and train-text-from-scratch, so it can't yet be moved to train.h|cpp
xaedes 2023-09-16 17:50:16 +02:00
parent a8c8907c62
commit ee27333b16
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
4 changed files with 157 additions and 107 deletions
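
As a condensed illustration of the pattern this commit introduces, here is a minimal, self-contained sketch. Struct and function names mirror the diff, but train_state is reduced to a stub and the actual GGUF file I/O is replaced with printf, so this is not the real llama.cpp code:

// Sketch of the save-callback pattern from this commit.
// Build: g++ -std=c++11 sketch.cpp && ./a.out
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

struct train_state { int64_t iter; };  // stub: the real struct lives in common/train.h

// the typedef added to train.h: any saver that opt_callback can invoke
typedef void (*save_train_files_callback)(void * data, struct train_state * train);

// per-example payload, like save_train_files_data in finetune.cpp
struct save_train_files_data {
    const char * fn_checkpoint_out;
    const char * pattern_fn_it;  // e.g. "ITERATION"
    const char * fn_latest;      // e.g. "LATEST"
};

// replaces the first occurrence of pattern in s (stand-in for replace_str)
static std::string replace_str(const char * s, const char * pattern, const char * repl) {
    std::string str(s);
    size_t pos = str.find(pattern);
    if (pos != std::string::npos) {
        str.replace(pos, strlen(pattern), repl);
    }
    return str;
}

// mirrors get_train_filename: iteration >= 0 substitutes the number,
// iteration < 0 substitutes the "latest" tag
static std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
    return replace_str(filename, pattern_it, sit.c_str());
}

// the saver: one file per iteration plus a rolling "latest" file
static void save_train_files(void * vdata, struct train_state * train) {
    struct save_train_files_data * data = (struct save_train_files_data *) vdata;
    if (strlen(data->fn_checkpoint_out) > 0) {
        printf("saving %s\n", get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, train->iter).c_str());
        printf("saving %s\n", get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1).c_str());
    }
}

// what opt_callback now holds instead of model/lora pointers
struct opt_callback_data {
    save_train_files_callback save_cb;
    void                    * save_data;
};

int main() {
    save_train_files_data save_data = { "chk-ITERATION.gguf", "ITERATION", "LATEST" };
    opt_callback_data cb_data = { &save_train_files, &save_data };

    train_state train = { 300 };
    if (cb_data.save_cb) {
        // prints: saving chk-300.gguf / saving chk-LATEST.gguf
        cb_data.save_cb(cb_data.save_data, &train);
    }
    return 0;
}

The point of the indirection is that opt_callback no longer needs to know about the example-specific model and lora types; it just invokes an opaque callback, which is what lets the callback plumbing eventually move into common train code.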

common/train.cpp

@@ -1001,3 +1001,8 @@ size_t tokenize_file(
     return out_tokens.size();
 }
 
+
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
+    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
+    return replace_str(filename, pattern_it, sit.c_str());
+}
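
For reference, get_train_filename substitutes either the iteration number or the "latest" tag into the output filename pattern. A hypothetical call, assuming the examples' usual pattern strings "ITERATION" and "LATEST":

get_train_filename("chk-lora-ITERATION.gguf", "ITERATION", "LATEST", 10);  // -> "chk-lora-10.gguf"
get_train_filename("chk-lora-ITERATION.gguf", "ITERATION", "LATEST", -1);  // -> "chk-lora-LATEST.gguf"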

common/train.h

@@ -133,3 +133,6 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
 bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
 void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
 
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
+
+typedef void (*save_train_files_callback)(void * data, struct train_state * train);

examples/finetune/finetune.cpp

@@ -1016,17 +1016,15 @@ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_mod
     return true;
 }
 
-static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
-    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
-    std::string fn  = replace_str(filename, pattern_it, sit.c_str());
-    printf("%s: saving to %s\n", __func__, fn.c_str());
+static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    printf("%s: saving to %s\n", __func__, filename);
 
     struct gguf_context * fctx = gguf_init_empty();
 
     save_checkpoint_lora_gguf(fctx, model, lora, train);
 
     // write file
     const bool only_meta = false;
-    gguf_write_to_file(fctx, fn.c_str(), only_meta);
+    gguf_write_to_file(fctx, filename, only_meta);
     gguf_free(fctx);
 }
@@ -1139,11 +1137,9 @@ static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor,
     file->write_raw(tensor->data, ggml_nbytes(tensor));
 }
 
-static void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration, const char * latest) {
-    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
-    std::string fn  = replace_str(filename, pattern_it, sit.c_str());
-    printf("%s: saving to %s\n", __func__, fn.c_str());
-    struct llama_file file(fn.c_str(), "wb");
+static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) {
+    printf("%s: saving to %s\n", __func__, filename);
+    struct llama_file file(filename, "wb");
     if (file.fp == NULL) {
         return;
     }
@@ -1823,25 +1819,49 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
     return true;
 }
 
-struct opt_callback_data {
-    struct train_params   * params;
-    struct train_state    * train;
+struct save_train_files_data {
+    const char            * fn_checkpoint_out;
+    const char            * fn_lora_out;
+    const char            * pattern_fn_it;
+    const char            * fn_latest;
     struct my_llama_model * model;
     struct my_llama_lora  * lora;
-    struct llama_context  * lctx;
-    int                     last_save_iter;
-    llama_token           * tokens_data;
-    size_t                  tokens_size;
-    size_t                * samples_begin;
-    size_t                * samples_size;
-    size_t                * shuffled_samples_begin;
-    size_t                * shuffled_samples_size;
-    size_t                  samples_count;
-    struct ggml_tensor    * tokens_input;
-    struct ggml_tensor    * target_probs;
-    int                     first_iter;
-    int64_t                 last_time;
-    double                  millis_per_iter;
+};
+
+static void save_train_files(void * vdata, struct train_state * train) {
+    struct save_train_files_data * data = (struct save_train_files_data *) vdata;
+
+    int64_t iter = train->opt->iter;
+
+    if (strlen(data->fn_checkpoint_out) > 0) {
+        save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train);
+        save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->model, data->lora, train);
+    }
+    if (strlen(data->fn_lora_out) > 0) {
+        save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora);
+        save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->lora);
+    }
+}
+
+struct opt_callback_data {
+    struct train_params      * params;
+    struct train_state       * train;
+    save_train_files_callback  save_cb;
+    void                     * save_data;
+    struct llama_context     * lctx;
+    int                        last_save_iter;
+    llama_token              * tokens_data;
+    size_t                     tokens_size;
+    size_t                   * samples_begin;
+    size_t                   * samples_size;
+    size_t                   * shuffled_samples_begin;
+    size_t                   * shuffled_samples_size;
+    size_t                     samples_count;
+    struct ggml_tensor       * tokens_input;
+    struct ggml_tensor       * target_probs;
+    int                        first_iter;
+    int64_t                    last_time;
+    double                     millis_per_iter;
 };
 
 static void opt_callback(void * vdata, int accum_step, float * sched) {
@@ -1881,14 +1901,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
         train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
         train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
 
-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        if (strlen(params->fn_lora_out) > 0) {
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
+        if (data->save_cb) {
+            data->save_cb(data->save_data, train);
         }
 
         data->last_save_iter = opt->iter;
     }
@@ -2140,10 +2156,17 @@ int main(int argc, char ** argv) {
     opt->iter = train->train_its;
 
     if (params.only_write_lora) {
-        if (strlen(params.fn_lora_out) > 0) {
-            save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
-            save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
-        }
+        save_train_files_data save_data;
+        save_data.fn_checkpoint_out = "";
+        save_data.fn_lora_out       = params.fn_lora_out;
+        save_data.pattern_fn_it     = params.pattern_fn_it;
+        save_data.fn_latest         = params.fn_latest;
+        save_data.model             = &model;
+        save_data.lora              = &lora;
+
+        save_train_files(&save_data, train);
+
+        free_train_state(train);
         ggml_free(lora.ctx);
         llama_free(lctx);
         llama_free_model(lmodel);
@@ -2323,12 +2346,20 @@ int main(int argc, char ** argv) {
     printf("%s: begin training\n", __func__);
 
+    save_train_files_data save_data;
+    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_lora_out       = params.fn_lora_out;
+    save_data.pattern_fn_it     = params.pattern_fn_it;
+    save_data.fn_latest         = params.fn_latest;
+    save_data.model             = &model;
+    save_data.lora              = &lora;
+
     struct opt_callback_data opt_cb_data;
     opt_cb_data.params = &params;
     opt_cb_data.train = train;
-    opt_cb_data.model = &model;
-    opt_cb_data.lora = &lora;
+    opt_cb_data.save_cb = &save_train_files;
+    opt_cb_data.save_data = &save_data;
     opt_cb_data.lctx = lctx;
     opt_cb_data.last_save_iter = opt->iter;
     opt_cb_data.tokens_data = train_tokens.data();
     opt_cb_data.tokens_size = train_tokens.size();
@@ -2374,14 +2405,7 @@ int main(int argc, char ** argv) {
     train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
     train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
 
-    if (strlen(params.fn_checkpoint_out) > 0) {
-        save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, opt->iter, params.fn_latest);
-        save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, -1, params.fn_latest);
-    }
-    if (strlen(params.fn_lora_out) > 0) {
-        save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
-        save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
-    }
+    save_train_files(&save_data, train);
 
     opt_cb_data.last_save_iter = opt->iter;
 }

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -640,17 +640,15 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
     }
 }
 
-static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, const char * pattern_it, int iteration, const char * latest) {
-    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
-    std::string fn  = replace_str(filename, pattern_it, sit.c_str());
-    printf("%s: saving to %s\n", __func__, fn.c_str());
+static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
+    printf("%s: saving to %s\n", __func__, filename);
 
     struct gguf_context * fctx = gguf_init_empty();
 
     save_llama_model_gguf(fctx, fn_vocab_model, model);
 
     // write file
     const bool only_meta = false;
-    gguf_write_to_file(fctx, fn.c_str(), only_meta);
+    gguf_write_to_file(fctx, filename, only_meta);
     gguf_free(fctx);
 }
@@ -681,17 +679,15 @@ static bool load_checkpoint_file(const char * filename, struct my_llama_model *
     return true;
 }
 
-static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
-    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
-    std::string fn  = replace_str(filename, pattern_it, sit.c_str());
-    printf("%s: saving to %s\n", __func__, fn.c_str());
+static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
+    printf("%s: saving to %s\n", __func__, filename);
 
     struct gguf_context * fctx = gguf_init_empty();
 
     save_checkpoint_gguf(fctx, fn_vocab_model, model, train);
 
     // write file
     const bool only_meta = false;
-    gguf_write_to_file(fctx, fn.c_str(), only_meta);
+    gguf_write_to_file(fctx, filename, only_meta);
     gguf_free(fctx);
 }
@@ -1220,25 +1216,50 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
     return true;
 }
 
-struct opt_callback_data {
-    struct train_params   * params;
-    struct train_state    * train;
+struct save_train_files_data {
+    const char            * fn_checkpoint_out;
+    const char            * fn_model_out;
+    const char            * fn_vocab_model;
+    const char            * pattern_fn_it;
+    const char            * fn_latest;
     struct my_llama_model * model;
-    struct llama_context  * lctx;
-    int                     last_save_iter;
-    llama_token           * tokens_data;
-    size_t                  tokens_size;
-    size_t                * samples_begin;
-    size_t                * samples_size;
-    size_t                * shuffled_samples_begin;
-    size_t                * shuffled_samples_size;
-    size_t                  samples_count;
-    struct ggml_tensor    * tokens_input;
-    struct ggml_tensor    * target_logits;
-    struct ggml_tensor    * target_probs;
-    int                     first_iter;
-    int64_t                 last_time;
-    double                  millis_per_iter;
+};
+
+static void save_train_files(void * vdata, struct train_state * train) {
+    struct save_train_files_data * data = (struct save_train_files_data *) vdata;
+
+    int64_t iter = train->opt->iter;
+
+    if (strlen(data->fn_checkpoint_out) > 0) {
+        save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model, train);
+        save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->fn_vocab_model, data->model, train);
+
+    }
+    if (strlen(data->fn_model_out) > 0) {
+        save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model);
+        save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->fn_vocab_model, data->model);
+    }
+}
+
+struct opt_callback_data {
+    struct train_params      * params;
+    struct train_state       * train;
+    save_train_files_callback  save_cb;
+    void                     * save_data;
+    struct llama_context     * lctx;
+    int                        last_save_iter;
+    llama_token              * tokens_data;
+    size_t                     tokens_size;
+    size_t                   * samples_begin;
+    size_t                   * samples_size;
+    size_t                   * shuffled_samples_begin;
+    size_t                   * shuffled_samples_size;
+    size_t                     samples_count;
+    struct ggml_tensor       * tokens_input;
+    struct ggml_tensor       * target_logits;
+    struct ggml_tensor       * target_probs;
+    int                        first_iter;
+    int64_t                    last_time;
+    double                     millis_per_iter;
 };
 
 static void opt_callback(void * vdata, int accum_step, float * sched) {
@@ -1278,15 +1299,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
         train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
         train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
 
-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        if (strlen(params->fn_model_out) > 0) {
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
-        }
+        if (data->save_cb) {
+            data->save_cb(data->save_data, train);
+        }
 
         data->last_save_iter = opt->iter;
     }
@@ -1508,14 +1524,23 @@ int main(int argc, char ** argv) {
            train_samples_size.size());
     printf("%s: begin training\n", __func__);
 
+    save_train_files_data save_data;
+    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_model_out      = params.fn_model_out;
+    save_data.fn_vocab_model    = params.fn_vocab_model;
+    save_data.pattern_fn_it     = params.pattern_fn_it;
+    save_data.fn_latest         = params.fn_latest;
+    save_data.model             = &model;
+
     struct opt_callback_data opt_cb_data;
     opt_cb_data.params = &params;
     opt_cb_data.train = train;
-    opt_cb_data.model = &model;
+    opt_cb_data.save_cb = &save_train_files;
+    opt_cb_data.save_data = &save_data;
     opt_cb_data.lctx = lctx;
     opt_cb_data.last_save_iter = opt->iter;
     opt_cb_data.tokens_data = train_tokens.data();
     opt_cb_data.tokens_size = train_tokens.size();
     opt_cb_data.samples_begin = train_samples_begin.data();
     opt_cb_data.samples_size = train_samples_size.data();
     opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
@@ -1620,22 +1645,15 @@ int main(int argc, char ** argv) {
     printf("%s: total training time=%f seconds\n", __func__, dd);
 
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
-    train->train_its += new_iters;
-    train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
-    train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
-
-    if (params.n_examples > 0) {
-        save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, opt->iter, params.fn_latest);
-        save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, -1, params.fn_latest);
-    }
-    if (strlen(params.fn_model_out) > 0) {
-        save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, opt->iter, params.fn_latest);
-        save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, -1, params.fn_latest);
-    }
-    opt_cb_data.last_save_iter = opt->iter;
+    if (new_iters > 0) {
+        train->train_its     += new_iters;
+        train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
+
+        save_train_files(&save_data, train);
+        opt_cb_data.last_save_iter = opt->iter;
+    }
 
     if (alloc) {
         ggml_allocr_free(alloc);
     }