move train data saving code into callback to unify code of opt_callback
train_params are still different in finetune and train-text-from-scratch, so it can't yet be moved to train.h|cpp
This commit is contained in:
parent
a8c8907c62
commit
ee27333b16
4 changed files with 157 additions and 107 deletions
|
@ -1001,3 +1001,8 @@ size_t tokenize_file(
|
|||
|
||||
return out_tokens.size();
|
||||
}
|
||||
|
||||
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
|
||||
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||
return replace_str(filename, pattern_it, sit.c_str());
|
||||
}
|
||||
|
|
|
@ -133,3 +133,6 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
|
|||
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
|
||||
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
|
||||
|
||||
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
|
||||
|
||||
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
|
||||
|
|
|
@ -1016,17 +1016,15 @@ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_mod
|
|||
return true;
|
||||
}
|
||||
|
||||
static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
|
||||
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||
static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
|
||||
printf("%s: saving to %s\n", __func__, filename);
|
||||
struct gguf_context * fctx = gguf_init_empty();
|
||||
|
||||
save_checkpoint_lora_gguf(fctx, model, lora, train);
|
||||
|
||||
// write file
|
||||
const bool only_meta = false;
|
||||
gguf_write_to_file(fctx, fn.c_str(), only_meta);
|
||||
gguf_write_to_file(fctx, filename, only_meta);
|
||||
gguf_free(fctx);
|
||||
}
|
||||
|
||||
|
@ -1139,11 +1137,9 @@ static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor,
|
|||
file->write_raw(tensor->data, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
static void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration, const char * latest) {
|
||||
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||
struct llama_file file(fn.c_str(), "wb");
|
||||
static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) {
|
||||
printf("%s: saving to %s\n", __func__, filename);
|
||||
struct llama_file file(filename, "wb");
|
||||
if (file.fp == NULL) {
|
||||
return;
|
||||
}
|
||||
|
@ -1823,25 +1819,49 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
|
|||
return true;
|
||||
}
|
||||
|
||||
struct opt_callback_data {
|
||||
struct train_params * params;
|
||||
struct train_state * train;
|
||||
struct save_train_files_data {
|
||||
const char * fn_checkpoint_out;
|
||||
const char * fn_lora_out;
|
||||
const char * pattern_fn_it;
|
||||
const char * fn_latest;
|
||||
struct my_llama_model * model;
|
||||
struct my_llama_lora * lora;
|
||||
struct llama_context * lctx;
|
||||
int last_save_iter;
|
||||
llama_token * tokens_data;
|
||||
size_t tokens_size;
|
||||
size_t * samples_begin;
|
||||
size_t * samples_size;
|
||||
size_t * shuffled_samples_begin;
|
||||
size_t * shuffled_samples_size;
|
||||
size_t samples_count;
|
||||
struct ggml_tensor * tokens_input;
|
||||
struct ggml_tensor * target_probs;
|
||||
int first_iter;
|
||||
int64_t last_time;
|
||||
double millis_per_iter;
|
||||
};
|
||||
|
||||
static void save_train_files(void * vdata, struct train_state * train) {
|
||||
struct save_train_files_data * data = (struct save_train_files_data *) vdata;
|
||||
|
||||
int64_t iter = train->opt->iter;
|
||||
|
||||
if (strlen(data->fn_checkpoint_out) > 0) {
|
||||
save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train);
|
||||
save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->model, data->lora, train);
|
||||
}
|
||||
if (strlen(data->fn_lora_out) > 0) {
|
||||
save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora);
|
||||
save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->lora);
|
||||
}
|
||||
}
|
||||
|
||||
struct opt_callback_data {
|
||||
struct train_params * params;
|
||||
struct train_state * train;
|
||||
save_train_files_callback save_cb;
|
||||
void * save_data;
|
||||
struct llama_context * lctx;
|
||||
int last_save_iter;
|
||||
llama_token * tokens_data;
|
||||
size_t tokens_size;
|
||||
size_t * samples_begin;
|
||||
size_t * samples_size;
|
||||
size_t * shuffled_samples_begin;
|
||||
size_t * shuffled_samples_size;
|
||||
size_t samples_count;
|
||||
struct ggml_tensor * tokens_input;
|
||||
struct ggml_tensor * target_probs;
|
||||
int first_iter;
|
||||
int64_t last_time;
|
||||
double millis_per_iter;
|
||||
};
|
||||
|
||||
static void opt_callback(void * vdata, int accum_step, float * sched) {
|
||||
|
@ -1881,14 +1901,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
|
|||
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
|
||||
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
|
||||
|
||||
if (strlen(params->fn_checkpoint_out) > 0) {
|
||||
save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, -1, params->fn_latest);
|
||||
}
|
||||
if (strlen(params->fn_lora_out) > 0) {
|
||||
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
|
||||
if (data->save_cb) {
|
||||
data->save_cb(data->save_data, train);
|
||||
}
|
||||
|
||||
data->last_save_iter = opt->iter;
|
||||
}
|
||||
|
||||
|
@ -2140,10 +2156,17 @@ int main(int argc, char ** argv) {
|
|||
opt->iter = train->train_its;
|
||||
|
||||
if (params.only_write_lora) {
|
||||
if (strlen(params.fn_lora_out) > 0) {
|
||||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||
}
|
||||
save_train_files_data save_data;
|
||||
save_data.fn_checkpoint_out = "";
|
||||
save_data.fn_lora_out = params.fn_lora_out;
|
||||
save_data.pattern_fn_it = params.pattern_fn_it;
|
||||
save_data.fn_latest = params.fn_latest;
|
||||
save_data.model = &model;
|
||||
save_data.lora = &lora;
|
||||
|
||||
save_train_files(&save_data, train);
|
||||
|
||||
free_train_state(train);
|
||||
ggml_free(lora.ctx);
|
||||
llama_free(lctx);
|
||||
llama_free_model(lmodel);
|
||||
|
@ -2323,12 +2346,20 @@ int main(int argc, char ** argv) {
|
|||
|
||||
printf("%s: begin training\n", __func__);
|
||||
|
||||
save_train_files_data save_data;
|
||||
save_data.fn_checkpoint_out = params.fn_checkpoint_out;
|
||||
save_data.fn_lora_out = params.fn_lora_out;
|
||||
save_data.pattern_fn_it = params.pattern_fn_it;
|
||||
save_data.fn_latest = params.fn_latest;
|
||||
save_data.model = &model;
|
||||
save_data.lora = &lora;
|
||||
|
||||
struct opt_callback_data opt_cb_data;
|
||||
opt_cb_data.params = ¶ms;
|
||||
opt_cb_data.train = train;
|
||||
opt_cb_data.model = &model;
|
||||
opt_cb_data.lora = &lora;
|
||||
opt_cb_data.lctx = lctx;
|
||||
opt_cb_data.params = ¶ms;
|
||||
opt_cb_data.train = train;
|
||||
opt_cb_data.save_cb = &save_train_files;
|
||||
opt_cb_data.save_data = &save_data;
|
||||
opt_cb_data.lctx = lctx;
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
opt_cb_data.tokens_data = train_tokens.data();
|
||||
opt_cb_data.tokens_size = train_tokens.size();
|
||||
|
@ -2374,14 +2405,7 @@ int main(int argc, char ** argv) {
|
|||
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
|
||||
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
|
||||
|
||||
if (strlen(params.fn_checkpoint_out) > 0) {
|
||||
save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, -1, params.fn_latest);
|
||||
}
|
||||
if (strlen(params.fn_lora_out) > 0) {
|
||||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||
}
|
||||
save_train_files(&save_data, train);
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
}
|
||||
|
||||
|
|
|
@ -640,17 +640,15 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
|
|||
}
|
||||
}
|
||||
|
||||
static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, const char * pattern_it, int iteration, const char * latest) {
|
||||
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||
static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
|
||||
printf("%s: saving to %s\n", __func__, filename);
|
||||
struct gguf_context * fctx = gguf_init_empty();
|
||||
|
||||
save_llama_model_gguf(fctx, fn_vocab_model, model);
|
||||
|
||||
// write file
|
||||
const bool only_meta = false;
|
||||
gguf_write_to_file(fctx, fn.c_str(), only_meta);
|
||||
gguf_write_to_file(fctx, filename, only_meta);
|
||||
gguf_free(fctx);
|
||||
}
|
||||
|
||||
|
@ -681,17 +679,15 @@ static bool load_checkpoint_file(const char * filename, struct my_llama_model *
|
|||
return true;
|
||||
}
|
||||
|
||||
static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
|
||||
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||
static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
|
||||
printf("%s: saving to %s\n", __func__, filename);
|
||||
struct gguf_context * fctx = gguf_init_empty();
|
||||
|
||||
save_checkpoint_gguf(fctx, fn_vocab_model, model, train);
|
||||
|
||||
// write file
|
||||
const bool only_meta = false;
|
||||
gguf_write_to_file(fctx, fn.c_str(), only_meta);
|
||||
gguf_write_to_file(fctx, filename, only_meta);
|
||||
gguf_free(fctx);
|
||||
}
|
||||
|
||||
|
@ -1220,25 +1216,50 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
|
|||
return true;
|
||||
}
|
||||
|
||||
struct opt_callback_data {
|
||||
struct train_params * params;
|
||||
struct train_state * train;
|
||||
struct save_train_files_data {
|
||||
const char * fn_checkpoint_out;
|
||||
const char * fn_model_out;
|
||||
const char * fn_vocab_model;
|
||||
const char * pattern_fn_it;
|
||||
const char * fn_latest;
|
||||
struct my_llama_model * model;
|
||||
struct llama_context * lctx;
|
||||
int last_save_iter;
|
||||
llama_token * tokens_data;
|
||||
size_t tokens_size;
|
||||
size_t * samples_begin;
|
||||
size_t * samples_size;
|
||||
size_t * shuffled_samples_begin;
|
||||
size_t * shuffled_samples_size;
|
||||
size_t samples_count;
|
||||
struct ggml_tensor * tokens_input;
|
||||
struct ggml_tensor * target_logits;
|
||||
struct ggml_tensor * target_probs;
|
||||
int first_iter;
|
||||
int64_t last_time;
|
||||
double millis_per_iter;
|
||||
};
|
||||
|
||||
static void save_train_files(void * vdata, struct train_state * train) {
|
||||
struct save_train_files_data * data = (struct save_train_files_data *) vdata;
|
||||
int64_t iter = train->opt->iter;
|
||||
|
||||
if (strlen(data->fn_checkpoint_out) > 0) {
|
||||
save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model, train);
|
||||
save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model, train);
|
||||
|
||||
}
|
||||
if (strlen(data->fn_model_out) > 0) {
|
||||
save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model);
|
||||
save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model);
|
||||
}
|
||||
}
|
||||
|
||||
struct opt_callback_data {
|
||||
struct train_params * params;
|
||||
struct train_state * train;
|
||||
save_train_files_callback save_cb;
|
||||
void * save_data;
|
||||
struct llama_context * lctx;
|
||||
int last_save_iter;
|
||||
llama_token * tokens_data;
|
||||
size_t tokens_size;
|
||||
size_t * samples_begin;
|
||||
size_t * samples_size;
|
||||
size_t * shuffled_samples_begin;
|
||||
size_t * shuffled_samples_size;
|
||||
size_t samples_count;
|
||||
struct ggml_tensor * tokens_input;
|
||||
struct ggml_tensor * target_logits;
|
||||
struct ggml_tensor * target_probs;
|
||||
int first_iter;
|
||||
int64_t last_time;
|
||||
double millis_per_iter;
|
||||
};
|
||||
|
||||
static void opt_callback(void * vdata, int accum_step, float * sched) {
|
||||
|
@ -1278,15 +1299,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
|
|||
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
|
||||
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
|
||||
|
||||
if (strlen(params->fn_checkpoint_out) > 0) {
|
||||
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, -1, params->fn_latest);
|
||||
if (data->save_cb) {
|
||||
data->save_cb(data->save_data, train);
|
||||
}
|
||||
|
||||
}
|
||||
if (strlen(params->fn_model_out) > 0) {
|
||||
save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
|
||||
}
|
||||
data->last_save_iter = opt->iter;
|
||||
}
|
||||
|
||||
|
@ -1508,14 +1524,23 @@ int main(int argc, char ** argv) {
|
|||
train_samples_size.size());
|
||||
printf("%s: begin training\n", __func__);
|
||||
|
||||
save_train_files_data save_data;
|
||||
save_data.fn_checkpoint_out = params.fn_checkpoint_out;
|
||||
save_data.fn_model_out = params.fn_model_out;
|
||||
save_data.fn_vocab_model = params.fn_vocab_model;
|
||||
save_data.pattern_fn_it = params.pattern_fn_it;
|
||||
save_data.fn_latest = params.fn_latest;
|
||||
save_data.model = &model;
|
||||
|
||||
struct opt_callback_data opt_cb_data;
|
||||
opt_cb_data.params = ¶ms;
|
||||
opt_cb_data.train = train;
|
||||
opt_cb_data.model = &model;
|
||||
opt_cb_data.lctx = lctx;
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
opt_cb_data.tokens_data = train_tokens.data();
|
||||
opt_cb_data.tokens_size = train_tokens.size();
|
||||
opt_cb_data.params = ¶ms;
|
||||
opt_cb_data.train = train;
|
||||
opt_cb_data.save_cb = &save_train_files;
|
||||
opt_cb_data.save_data = &save_data;
|
||||
opt_cb_data.lctx = lctx;
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
opt_cb_data.tokens_data = train_tokens.data();
|
||||
opt_cb_data.tokens_size = train_tokens.size();
|
||||
opt_cb_data.samples_begin = train_samples_begin.data();
|
||||
opt_cb_data.samples_size = train_samples_size.data();
|
||||
opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
|
||||
|
@ -1620,22 +1645,15 @@ int main(int argc, char ** argv) {
|
|||
printf("%s: total training time=%f seconds\n", __func__, dd);
|
||||
|
||||
int new_iters = opt->iter - opt_cb_data.last_save_iter;
|
||||
train->train_its += new_iters;
|
||||
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
|
||||
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
|
||||
if (new_iters > 0) {
|
||||
train->train_its += new_iters;
|
||||
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
|
||||
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
|
||||
|
||||
if (params.n_examples > 0) {
|
||||
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, -1, params.fn_latest);
|
||||
save_train_files(&save_data, train);
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
}
|
||||
|
||||
if (strlen(params.fn_model_out) > 0) {
|
||||
save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, -1, params.fn_latest);
|
||||
}
|
||||
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
|
||||
if (alloc) {
|
||||
ggml_allocr_free(alloc);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue