add option to save train-text-from-scratch output every N iterations
This commit is contained in:
parent
f3590ad8d9
commit
b26bd4c34c
1 changed files with 83 additions and 6 deletions
|
@ -793,6 +793,15 @@ void shuffle_ints(int * begin, int * end) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string replace_str(const char * s, const char * needle, const char * replacement) {
|
||||||
|
std::string str = s;
|
||||||
|
size_t pos = str.find(needle);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
str.replace(pos, strlen(needle), replacement);
|
||||||
|
}
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
|
||||||
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
|
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
|
||||||
{ \
|
{ \
|
||||||
const std::string skey(key); \
|
const std::string skey(key); \
|
||||||
|
@ -1174,14 +1183,17 @@ void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_mod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
|
void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, const char * pattern_it, int iteration, const char * latest) {
|
||||||
|
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||||
|
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||||
|
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||||
struct gguf_context * fctx = gguf_init_empty();
|
struct gguf_context * fctx = gguf_init_empty();
|
||||||
|
|
||||||
save_llama_model_gguf(fctx, fn_vocab_model, model);
|
save_llama_model_gguf(fctx, fn_vocab_model, model);
|
||||||
|
|
||||||
// write file
|
// write file
|
||||||
const bool only_meta = false;
|
const bool only_meta = false;
|
||||||
gguf_write_to_file(fctx, filename, only_meta);
|
gguf_write_to_file(fctx, fn.c_str(), only_meta);
|
||||||
gguf_free(fctx);
|
gguf_free(fctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1234,14 +1246,17 @@ bool load_checkpoint_file(const char * filename, struct my_llama_model * model,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
|
void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) {
|
||||||
|
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
|
||||||
|
std::string fn = replace_str(filename, pattern_it, sit.c_str());
|
||||||
|
printf("%s: saving to %s\n", __func__, fn.c_str());
|
||||||
struct gguf_context * fctx = gguf_init_empty();
|
struct gguf_context * fctx = gguf_init_empty();
|
||||||
|
|
||||||
save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);
|
save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);
|
||||||
|
|
||||||
// write file
|
// write file
|
||||||
const bool only_meta = false;
|
const bool only_meta = false;
|
||||||
gguf_write_to_file(fctx, filename, only_meta);
|
gguf_write_to_file(fctx, fn.c_str(), only_meta);
|
||||||
gguf_free(fctx);
|
gguf_free(fctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1270,6 +1285,10 @@ struct train_params {
|
||||||
const char * fn_checkpoint_in;
|
const char * fn_checkpoint_in;
|
||||||
const char * fn_checkpoint_out;
|
const char * fn_checkpoint_out;
|
||||||
const char * fn_model_out;
|
const char * fn_model_out;
|
||||||
|
const char * pattern_fn_it;
|
||||||
|
const char * fn_latest;
|
||||||
|
|
||||||
|
int save_every;
|
||||||
|
|
||||||
uint32_t seed;
|
uint32_t seed;
|
||||||
|
|
||||||
|
@ -1329,6 +1348,10 @@ struct train_params get_default_train_params() {
|
||||||
params.fn_checkpoint_in = "checkpoint.bin";
|
params.fn_checkpoint_in = "checkpoint.bin";
|
||||||
params.fn_checkpoint_out = "checkpoint.bin";
|
params.fn_checkpoint_out = "checkpoint.bin";
|
||||||
params.fn_model_out = "ggml-checkpoint-f32.bin";
|
params.fn_model_out = "ggml-checkpoint-f32.bin";
|
||||||
|
params.pattern_fn_it = "ITERATION";
|
||||||
|
params.fn_latest = "LATEST";
|
||||||
|
|
||||||
|
params.save_every = 10;
|
||||||
|
|
||||||
params.seed = -1;
|
params.seed = -1;
|
||||||
|
|
||||||
|
@ -1392,6 +1415,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
|
||||||
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
|
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
|
||||||
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
|
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
|
||||||
fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
|
fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
|
||||||
|
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
|
||||||
|
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
|
||||||
|
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
|
||||||
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
|
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
|
||||||
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
|
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
|
||||||
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
|
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
|
||||||
|
@ -1481,6 +1507,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->fn_model_out = argv[i];
|
params->fn_model_out = argv[i];
|
||||||
|
} else if (arg == "--pattern-fn-it") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params->pattern_fn_it = argv[i];
|
||||||
|
} else if (arg == "--fn-latest") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params->fn_latest = argv[i];
|
||||||
|
} else if (arg == "--save-every") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params->save_every = std::stoi(argv[i]);
|
||||||
} else if (arg == "-s" || arg == "--seed") {
|
} else if (arg == "-s" || arg == "--seed") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -1722,7 +1766,9 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
struct opt_callback_data {
|
struct opt_callback_data {
|
||||||
struct train_params * params;
|
struct train_params * params;
|
||||||
struct ggml_opt_context * opt;
|
struct ggml_opt_context * opt;
|
||||||
|
struct my_llama_model * model;
|
||||||
struct llama_context * lctx;
|
struct llama_context * lctx;
|
||||||
|
int last_save_iter;
|
||||||
llama_token * tokens_data;
|
llama_token * tokens_data;
|
||||||
size_t tokens_size;
|
size_t tokens_size;
|
||||||
int * samples_data;
|
int * samples_data;
|
||||||
|
@ -1738,6 +1784,26 @@ void opt_callback(void * vdata, float * sched) {
|
||||||
struct train_params * params = data->params;
|
struct train_params * params = data->params;
|
||||||
struct ggml_opt_context * opt = data->opt;
|
struct ggml_opt_context * opt = data->opt;
|
||||||
int n_batch = params->n_batch;
|
int n_batch = params->n_batch;
|
||||||
|
int n_ctx = params->n_ctx;
|
||||||
|
|
||||||
|
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
|
||||||
|
if (save_now) {
|
||||||
|
int new_iters = opt->iter - data->last_save_iter;
|
||||||
|
data->model->train_its += new_iters;
|
||||||
|
data->model->train_samples += new_iters * n_batch;
|
||||||
|
data->model->train_tokens += new_iters * n_batch * n_ctx;
|
||||||
|
|
||||||
|
if (strlen(params->fn_checkpoint_out) > 0) {
|
||||||
|
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||||
|
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest);
|
||||||
|
|
||||||
|
}
|
||||||
|
if (strlen(params->fn_model_out) > 0) {
|
||||||
|
save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||||
|
save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
|
||||||
|
}
|
||||||
|
data->last_save_iter = opt->iter;
|
||||||
|
}
|
||||||
|
|
||||||
*sched = (opt->iter < params->warmup)
|
*sched = (opt->iter < params->warmup)
|
||||||
? (float) opt->iter / (float) params->warmup
|
? (float) opt->iter / (float) params->warmup
|
||||||
|
@ -1929,7 +1995,9 @@ int main(int argc, char ** argv) {
|
||||||
struct opt_callback_data opt_cb_data;
|
struct opt_callback_data opt_cb_data;
|
||||||
opt_cb_data.params = ¶ms;
|
opt_cb_data.params = ¶ms;
|
||||||
opt_cb_data.opt = opt;
|
opt_cb_data.opt = opt;
|
||||||
|
opt_cb_data.model = &model;
|
||||||
opt_cb_data.lctx = lctx;
|
opt_cb_data.lctx = lctx;
|
||||||
|
opt_cb_data.last_save_iter = opt->iter;
|
||||||
opt_cb_data.tokens_data = train_tokens.data();
|
opt_cb_data.tokens_data = train_tokens.data();
|
||||||
opt_cb_data.tokens_size = train_tokens.size();
|
opt_cb_data.tokens_size = train_tokens.size();
|
||||||
opt_cb_data.samples_data = train_samples.data();
|
opt_cb_data.samples_data = train_samples.data();
|
||||||
|
@ -2038,14 +2106,23 @@ int main(int argc, char ** argv) {
|
||||||
double dd = (double) d * 1e-3;
|
double dd = (double) d * 1e-3;
|
||||||
printf("%s: total training time=%f seconds\n", __func__, dd);
|
printf("%s: total training time=%f seconds\n", __func__, dd);
|
||||||
|
|
||||||
|
int new_iters = opt->iter - opt_cb_data.last_save_iter;
|
||||||
|
model.train_its += new_iters;
|
||||||
|
model.train_samples += new_iters * n_batch;
|
||||||
|
model.train_tokens += new_iters * n_batch * n_tokens;
|
||||||
|
|
||||||
if (params.n_examples > 0) {
|
if (params.n_examples > 0) {
|
||||||
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
|
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||||
|
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, -1, params.fn_latest);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (strlen(params.fn_model_out) > 0) {
|
if (strlen(params.fn_model_out) > 0) {
|
||||||
save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model);
|
save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||||
|
save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, -1, params.fn_latest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
opt_cb_data.last_save_iter = opt->iter;
|
||||||
|
|
||||||
if (alloc) {
|
if (alloc) {
|
||||||
ggml_allocr_free(alloc);
|
ggml_allocr_free(alloc);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue