also save latest finetune output with ITERATION="LATEST" and print where files are saved
Saving with LATEST makes it easier to resume training from the latest checkpoint. The string "LATEST" can be configured with the command-line option "--fn-latest STR".
This commit is contained in:
parent
27c24ffa1b
commit
8b4106ae33
1 changed files with 26 additions and 11 deletions
|
@ -1865,9 +1865,10 @@ std::string replace_str(const char * s, const char * needle, const char * replac
|
||||||
return str;
|
return str;
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * filename, const char * pattern_it, int iteration) {
|
void save_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * filename, const char * pattern_it, int iteration, const char * latest) {
|
||||||
std::string sit = std::to_string(iteration);
|
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||||
|
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||||
struct llama_file file(fn.c_str(), "wb");
|
struct llama_file file(fn.c_str(), "wb");
|
||||||
if (file.fp == NULL) {
|
if (file.fp == NULL) {
|
||||||
return;
|
return;
|
||||||
|
@ -2032,9 +2033,10 @@ bool load_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora,
|
||||||
return (file.fp != NULL);
|
return (file.fp != NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration) {
|
void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration, const char * latest) {
|
||||||
std::string sit = std::to_string(iteration);
|
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||||
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||||
|
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||||
struct llama_file file(fn.c_str(), "wb");
|
struct llama_file file(fn.c_str(), "wb");
|
||||||
if (file.fp == NULL) {
|
if (file.fp == NULL) {
|
||||||
return;
|
return;
|
||||||
|
@ -2102,6 +2104,7 @@ struct train_params {
|
||||||
const char * fn_checkpoint_out;
|
const char * fn_checkpoint_out;
|
||||||
const char * fn_lora_out;
|
const char * fn_lora_out;
|
||||||
const char * pattern_fn_it;
|
const char * pattern_fn_it;
|
||||||
|
const char * fn_latest;
|
||||||
|
|
||||||
int save_every;
|
int save_every;
|
||||||
|
|
||||||
|
@ -2173,6 +2176,7 @@ struct train_params get_default_train_params() {
|
||||||
params.fn_checkpoint_out = "checkpoint-ITERATION.bin";
|
params.fn_checkpoint_out = "checkpoint-ITERATION.bin";
|
||||||
params.fn_lora_out = "ggml-lora-ITERATION-f32.bin";
|
params.fn_lora_out = "ggml-lora-ITERATION-f32.bin";
|
||||||
params.pattern_fn_it = "ITERATION";
|
params.pattern_fn_it = "ITERATION";
|
||||||
|
params.fn_latest = "LATEST";
|
||||||
|
|
||||||
params.save_every = 10;
|
params.save_every = 10;
|
||||||
|
|
||||||
|
@ -2248,7 +2252,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
|
||||||
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
|
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
|
||||||
fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
|
fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
|
||||||
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
|
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
|
||||||
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%s')\n", params->save_every);
|
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
|
||||||
|
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
|
||||||
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
|
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
|
||||||
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
|
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
|
||||||
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
|
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
|
||||||
|
@ -2352,6 +2357,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->pattern_fn_it = argv[i];
|
params->pattern_fn_it = argv[i];
|
||||||
|
} else if (arg == "--fn-latest") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params->fn_latest = argv[i];
|
||||||
} else if (arg == "--save-every") {
|
} else if (arg == "--save-every") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -2669,11 +2680,13 @@ void opt_callback(void * vdata, float * sched) {
|
||||||
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
|
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
|
||||||
if (save_now) {
|
if (save_now) {
|
||||||
if (strlen(params->fn_checkpoint_out) > 0) {
|
if (strlen(params->fn_checkpoint_out) > 0) {
|
||||||
save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter);
|
save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||||
}
|
save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, -1, params->fn_latest);
|
||||||
|
}
|
||||||
if (strlen(params->fn_lora_out) > 0) {
|
if (strlen(params->fn_lora_out) > 0) {
|
||||||
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter);
|
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||||
}
|
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
|
||||||
|
}
|
||||||
data->last_save_iter = opt->iter;
|
data->last_save_iter = opt->iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3038,11 +3051,13 @@ int main(int argc, char ** argv) {
|
||||||
printf("%s: total training time=%f seconds\n", __func__, dd);
|
printf("%s: total training time=%f seconds\n", __func__, dd);
|
||||||
|
|
||||||
if (params.n_examples > 0) {
|
if (params.n_examples > 0) {
|
||||||
save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter);
|
save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||||
|
save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (strlen(params.fn_lora_out) > 0) {
|
if (strlen(params.fn_lora_out) > 0) {
|
||||||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter);
|
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||||
|
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue