move common train params into common/train

xaedes 2023-09-16 18:45:59 +02:00
parent ee27333b16
commit e9758ae1d2
4 changed files with 552 additions and 821 deletions
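
The change itself is mechanical: the fields that every training example duplicated (data and checkpoint paths, filename patterns, sampling options, optimizer settings) move into a shared struct train_params_common, and each tool keeps only its own fields next to it. A minimal sketch of the resulting pattern, assuming the new declarations land in a common "train.h" header; the struct name my_train_params and its layout are illustrative, not part of the commit:

#include "train.h"  // assumed header name for the new common train declarations

// A tool-specific parameter struct now embeds the shared block.
struct my_train_params {
    struct train_params_common common;  // fields shared by all training tools
    const char * fn_model_base;         // tool-specific fields stay local
};

static struct my_train_params get_default_my_train_params() {
    struct my_train_params params;
    params.common        = get_default_train_params_common();
    params.fn_model_base = "";
    return params;
}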


@ -1006,3 +1006,326 @@ std::string get_train_filename(const char * filename, const char * pattern_it, c
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
return replace_str(filename, pattern_it, sit.c_str()); return replace_str(filename, pattern_it, sit.c_str());
} }
struct train_params_common get_default_train_params_common() {
struct train_params_common params;
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.gguf";
params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.print_usage = false;
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128;
params.n_threads = 6;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.custom_n_ctx = false;
params.use_flash = true;
params.use_checkpointing = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
return params;
}
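
These defaults feed the get_train_filename helper shown at the top of this hunk: whenever a checkpoint is written (every save_every iterations), the pattern_fn_it placeholder in an output name is replaced by the iteration number, or by fn_latest for the rolling "latest" copy. A hedged sketch of the resulting names; the tail of the signature is cut off in the hunk header above, so the trailing (const char * latest, int64_t iteration) arguments are an assumption:

// get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", 30) -> "checkpoint-30.gguf"
// get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", -1) -> "checkpoint-LATEST.gguf"
std::string fn_out = get_train_filename(params.fn_checkpoint_out, params.pattern_fn_it, params.fn_latest, 30);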
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) {
// fprintf(stderr, "usage: %s [options]\n", argv[0]);
// fprintf(stderr, "\n");
// fprintf(stderr, "options:\n");
// fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, "\n");
}
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param) {
int& i = *idx;
std::string arg = argv[i];
if (arg == "--train-data") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_ctx = std::stoi(argv[i]);
params->custom_n_ctx = true;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "--sample-start") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_steps = std::stoi(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-h" || arg == "--help") {
params->print_usage = true;
return true;
} else {
return false;
}
return true;
}
void finish_processing_train_args(struct train_params_common * params) {
if (params->escape) {
process_escapes(params->sample_start);
}
}
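
finish_processing_train_args is the shared post-parse step: when --escape was given it runs process_escapes over sample_start, so escape sequences typed on the command line turn into real characters (per the --escape help text above: \n, \r, \t, \', \", \\). A small illustration, with the shell invocation as an assumed example:

// my-tool --sample-start "<s>\n" --escape
// While parsing, sample_start ends with a literal backslash followed by 'n';
// after finish_processing_train_args() it ends with an actual newline character.
finish_processing_train_args(&params.common);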


@ -26,9 +26,69 @@ struct train_state {
size_t shuffle_next_sample; size_t shuffle_next_sample;
}; };
struct train_params_common {
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * pattern_fn_it;
const char * fn_latest;
bool print_usage;
int save_every;
uint32_t seed;
int n_ctx;
int n_threads;
int n_batch;
int n_gradient_accumulation;
bool custom_n_ctx;
bool use_flash;
bool use_checkpointing;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
};
struct train_state * init_train_state(int seed); struct train_state * init_train_state(int seed);
void free_train_state(struct train_state * state); void free_train_state(struct train_state * state);
struct train_params_common get_default_train_params_common();
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
void finish_processing_train_args(struct train_params_common * params);
struct random_normal_distribution; struct random_normal_distribution;
struct random_uniform_distribution; struct random_uniform_distribution;
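
The four declarations above are all an example program needs for its argument handling. A condensed sketch of the dispatch pattern, mirroring the train_params_parse changes in the finetune example further down; it builds on the hypothetical my_train_params struct sketched near the top (plus the usual <cstdio>, <cstdlib> and <string> includes), with error handling trimmed:

static bool my_params_parse(int argc, char ** argv, struct my_train_params * params) {
    bool invalid_param = false;
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
            if (invalid_param) {
                break;                               // a common arg was present but malformed
            } else if (params->common.print_usage) {
                print_common_train_usage(argc, argv, &params->common);
                exit(0);
            }
        } else if (arg == "--model-base") {          // program-specific arguments go here
            if (++i >= argc) { invalid_param = true; break; }
            params->fn_model_base = argv[i];
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return false;
        }
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument\n");
        return false;
    }
    finish_processing_train_args(&params->common);   // e.g. resolve --escape sequences
    return true;
}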


@ -1197,24 +1197,10 @@ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lor
} }
struct train_params { struct train_params {
struct train_params_common common;
const char * fn_model_base; const char * fn_model_base;
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * fn_lora_out; const char * fn_lora_out;
const char * pattern_fn_it;
const char * fn_latest;
int save_every;
uint32_t seed;
int n_ctx;
int n_threads;
int n_batch;
int n_gradient_accumulation;
bool custom_n_ctx;
bool only_write_lora; bool only_write_lora;
@ -1255,61 +1241,13 @@ struct train_params {
bool custom_n_rank_tok_embeddings; bool custom_n_rank_tok_embeddings;
bool custom_n_rank_norm; bool custom_n_rank_norm;
bool custom_n_rank_output; bool custom_n_rank_output;
bool use_flash;
bool use_checkpointing;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
}; };
static struct train_params get_default_train_params() { static struct train_params get_default_train_params() {
struct train_params params; struct train_params params;
params.common = get_default_train_params_common();
params.fn_model_base = ""; params.fn_model_base = "";
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.gguf";
params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf"; params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128;
params.n_threads = 6;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.custom_n_ctx = false;
params.only_write_lora = false; params.only_write_lora = false;
@ -1351,59 +1289,18 @@ static struct train_params get_default_train_params() {
params.custom_n_rank_norm = false; params.custom_n_rank_norm = false;
params.custom_n_rank_output = false; params.custom_n_rank_output = false;
params.use_flash = true;
params.use_checkpointing = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
return params; return params;
} }
static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) { static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base); fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base);
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out); fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " --only-write-lora only save llama lora, don't do any training\n"); fprintf(stderr, " --only-write-lora only save llama lora, don't do any training\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
@ -1421,39 +1318,8 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n"); fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n"); fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n"); fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); print_common_train_usage(argc, argv, &params->common);
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, "\n");
} }
static bool train_params_parse(int argc, char ** argv, struct train_params * params) { static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@ -1468,87 +1334,27 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
std::replace(arg.begin(), arg.end(), '_', '-'); std::replace(arg.begin(), arg.end(), '_', '-');
} }
if (arg == "--model-base") { if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
if (invalid_param) {
break;
} else if (params->common.print_usage) {
train_print_usage(argc, argv, &default_params);
exit(0);
}
} else if (arg == "--model-base") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params->fn_model_base = argv[i]; params->fn_model_base = argv[i];
} else if (arg == "--train-data") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--lora-out") { } else if (arg == "--lora-out") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params->fn_lora_out = argv[i]; params->fn_lora_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "--only-write-lora") { } else if (arg == "--only-write-lora") {
params->only_write_lora = true; params->only_write_lora = true;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_ctx = std::stoi(argv[i]);
params->custom_n_ctx = true;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "--norm-rms-eps") { } else if (arg == "--norm-rms-eps") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -1667,141 +1473,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
} }
params->n_rank_w3 = std::stoi(argv[i]); params->n_rank_w3 = std::stoi(argv[i]);
params->custom_n_rank_w3 = true; params->custom_n_rank_w3 = true;
} else if (arg == "--sample-start") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_steps = std::stof(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params); train_print_usage(argc, argv, &default_params);
@ -1813,9 +1484,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
train_print_usage(argc, argv, &default_params); train_print_usage(argc, argv, &default_params);
exit(1); exit(1);
} }
if (params->escape) { finish_processing_train_args(&params->common);
process_escapes(params->sample_start);
}
return true; return true;
} }
@ -1844,31 +1513,31 @@ static void save_train_files(void * vdata, struct train_state * train) {
} }
struct opt_callback_data { struct opt_callback_data {
struct train_params * params; struct train_params_common * params;
struct train_state * train; struct train_state * train;
save_train_files_callback save_cb; save_train_files_callback save_cb;
void * save_data; void * save_data;
struct llama_context * lctx; struct llama_context * lctx;
int last_save_iter; int last_save_iter;
llama_token * tokens_data; llama_token * tokens_data;
size_t tokens_size; size_t tokens_size;
size_t * samples_begin; size_t * samples_begin;
size_t * samples_size; size_t * samples_size;
size_t * shuffled_samples_begin; size_t * shuffled_samples_begin;
size_t * shuffled_samples_size; size_t * shuffled_samples_size;
size_t samples_count; size_t samples_count;
struct ggml_tensor * tokens_input; struct ggml_tensor * tokens_input;
struct ggml_tensor * target_probs; struct ggml_tensor * target_probs;
int first_iter; int first_iter;
int64_t last_time; int64_t last_time;
double millis_per_iter; double millis_per_iter;
}; };
static void opt_callback(void * vdata, int accum_step, float * sched) { static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata; struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params; struct train_params_common * params = data->params;
struct train_state * train = data->train; struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt; struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch; int n_batch = params->n_batch;
int n_ctx = params->n_ctx; int n_ctx = params->n_ctx;
@ -2019,11 +1688,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
if (params.seed == LLAMA_DEFAULT_SEED) { if (params.common.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL); params.common.seed = time(NULL);
} }
printf("%s: seed: %u\n", __func__, params.seed); printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.seed); srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params(); struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = false; llama_params.vocab_only = false;
@ -2033,11 +1702,11 @@ int main(int argc, char ** argv) {
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct my_llama_model model; struct my_llama_model model;
init_model(lmodel, &model, params.n_ctx); init_model(lmodel, &model, params.common.n_ctx);
struct my_llama_lora lora; struct my_llama_lora lora;
struct train_state * train = init_train_state(params.seed); struct train_state * train = init_train_state(params.common.seed);
struct ggml_opt_context * opt = train->opt; struct ggml_opt_context * opt = train->opt;
load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams); load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
@ -2083,30 +1752,30 @@ int main(int argc, char ** argv) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM); opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false; opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false; opt->params.print_backward_graph = false;
opt->params.n_threads = params.n_threads; opt->params.n_threads = params.common.n_threads;
opt->params.past = params.opt_past; opt->params.past = params.common.opt_past;
opt->params.delta = params.opt_delta; opt->params.delta = params.common.opt_delta;
opt->params.max_no_improvement = params.opt_max_no_improvement; opt->params.max_no_improvement = params.common.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.n_gradient_accumulation; opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
opt->params.adam.n_iter = params.adam_n_iter; opt->params.adam.n_iter = params.common.adam_n_iter;
opt->params.adam.sched = 1.0f; opt->params.adam.sched = 1.0f;
opt->params.adam.alpha = params.adam_alpha; opt->params.adam.alpha = params.common.adam_alpha;
opt->params.adam.decay = params.adam_decay; opt->params.adam.decay = params.common.adam_decay;
opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim; opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
opt->params.adam.beta1 = params.adam_beta1; opt->params.adam.beta1 = params.common.adam_beta1;
opt->params.adam.beta2 = params.adam_beta2; opt->params.adam.beta2 = params.common.adam_beta2;
opt->params.adam.gclip = params.adam_gclip; opt->params.adam.gclip = params.common.adam_gclip;
opt->params.adam.eps_f = params.adam_eps_f; opt->params.adam.eps_f = params.common.adam_eps_f;
ggml_allocr * alloc = NULL; ggml_allocr * alloc = NULL;
printf("%s: init model\n", __func__); printf("%s: init model\n", __func__);
bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train); bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
if (existed) { if (existed) {
// overwrite last n_ctx with user provided n_ctx // overwrite last n_ctx with user provided n_ctx
if (params.custom_n_ctx) { if (params.common.custom_n_ctx) {
model.hparams.n_ctx = params.n_ctx; model.hparams.n_ctx = params.common.n_ctx;
} }
const bool opt_param_count_changed = ( const bool opt_param_count_changed = (
@ -2124,7 +1793,7 @@ int main(int argc, char ** argv) {
|| (lora.hparams.n_rank_output != n_rank_output) || (lora.hparams.n_rank_output != n_rank_output)
); );
const bool opt_past_changed = opt->params.past != params.opt_past; const bool opt_past_changed = opt->params.past != params.common.opt_past;
if (opt_param_count_changed) { if (opt_param_count_changed) {
print_lora_params(&lora.hparams); print_lora_params(&lora.hparams);
@ -2139,7 +1808,7 @@ int main(int argc, char ** argv) {
} }
} else { // existed == false } else { // existed == false
init_lora(&model, &lora); init_lora(&model, &lora);
randomize_lora(&lora, params.seed, 0.0f, 1.0f, -1.0f, +1.0f); randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
if (!params.only_write_lora) { if (!params.only_write_lora) {
ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora)); ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
} }
@ -2159,8 +1828,8 @@ int main(int argc, char ** argv) {
save_train_files_data save_data; save_train_files_data save_data;
save_data.fn_checkpoint_out = ""; save_data.fn_checkpoint_out = "";
save_data.fn_lora_out = params.fn_lora_out; save_data.fn_lora_out = params.fn_lora_out;
save_data.pattern_fn_it = params.pattern_fn_it; save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.fn_latest; save_data.fn_latest = params.common.fn_latest;
save_data.model = &model; save_data.model = &model;
save_data.lora = &lora; save_data.lora = &lora;
@ -2175,7 +1844,7 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx; int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
int n_batch = params.n_batch; int n_batch = params.common.n_batch;
printf("%s: opt iter %d\n", __func__, opt->iter); printf("%s: opt iter %d\n", __func__, opt->iter);
@ -2215,7 +1884,7 @@ int main(int argc, char ** argv) {
size_t estimated_compute_size_wo_data = ( size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2 ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*( + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.use_checkpointing ? 3 : 2 params.common.use_checkpointing ? 3 : 2
) )
); );
struct ggml_init_params ctx_compute_params = { struct ggml_init_params ctx_compute_params = {
@ -2242,7 +1911,7 @@ int main(int argc, char ** argv) {
gf = ggml_new_graph(ctx_compute); gf = ggml_new_graph(ctx_compute);
gf->order = (enum ggml_cgraph_eval_order) order; gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute); gb = ggml_new_graph(ctx_compute);
gb_tmp = params.use_checkpointing gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute) ? ggml_new_graph(ctx_compute)
: NULL; : NULL;
loss = llama_build_lora_finetune_graphs( loss = llama_build_lora_finetune_graphs(
@ -2250,8 +1919,8 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp, gf, gb, gb_tmp,
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.use_flash, params.common.use_flash,
params.use_checkpointing params.common.use_checkpointing
); );
size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment; size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
if (max_compute_size < best_compute_size) { if (max_compute_size < best_compute_size) {
@ -2275,7 +1944,7 @@ int main(int argc, char ** argv) {
gf = ggml_new_graph(ctx_compute); gf = ggml_new_graph(ctx_compute);
gf->order = best_order; gf->order = best_order;
gb = ggml_new_graph(ctx_compute); gb = ggml_new_graph(ctx_compute);
gb_tmp = params.use_checkpointing gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute) ? ggml_new_graph(ctx_compute)
: NULL; : NULL;
loss = llama_build_lora_finetune_graphs( loss = llama_build_lora_finetune_graphs(
@ -2283,8 +1952,8 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp, gf, gb, gb_tmp,
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.use_flash, params.common.use_flash,
params.use_checkpointing params.common.use_checkpointing
); );
ggml_allocr_free(alloc); ggml_allocr_free(alloc);
@ -2294,10 +1963,10 @@ int main(int argc, char ** argv) {
std::vector<size_t> train_samples_size; std::vector<size_t> train_samples_size;
printf("%s: tokenize training data\n", __func__); printf("%s: tokenize training data\n", __func__);
tokenize_file(lctx, tokenize_file(lctx,
params.fn_train_data, params.common.fn_train_data,
params.sample_start, params.common.sample_start,
params.include_sample_start, params.common.include_sample_start,
params.overlapping_samples, params.common.overlapping_samples,
n_tokens, n_tokens,
train_tokens, train_tokens,
train_samples_begin, train_samples_begin,
@ -2318,16 +1987,16 @@ int main(int argc, char ** argv) {
} }
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size()); const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) { if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
} }
if (params.force_reshuffle) { if (params.common.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
} }
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed); train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
train->shuffle_sample_count = train_samples_size.size(); train->shuffle_sample_count = train_samples_size.size();
train->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
@ -2347,15 +2016,15 @@ int main(int argc, char ** argv) {
printf("%s: begin training\n", __func__); printf("%s: begin training\n", __func__);
save_train_files_data save_data; save_train_files_data save_data;
save_data.fn_checkpoint_out = params.fn_checkpoint_out; save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
save_data.fn_lora_out = params.fn_lora_out; save_data.fn_lora_out = params.fn_lora_out;
save_data.pattern_fn_it = params.pattern_fn_it; save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.fn_latest; save_data.fn_latest = params.common.fn_latest;
save_data.model = &model; save_data.model = &model;
save_data.lora = &lora; save_data.lora = &lora;
struct opt_callback_data opt_cb_data; struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params; opt_cb_data.params = &params.common;
opt_cb_data.train = train; opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files; opt_cb_data.save_cb = &save_train_files;
opt_cb_data.save_data = &save_data; opt_cb_data.save_data = &save_data;
@ -2375,7 +2044,7 @@ int main(int argc, char ** argv) {
opt_cb_data.millis_per_iter = 0.0; opt_cb_data.millis_per_iter = 0.0;
// measure required memory for work buffer // measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.n_threads).work_size + GGML_OBJECT_SIZE; size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
printf("%s: max_work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f)); printf("%s: max_work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
// context for work buffer // context for work buffer


@ -692,17 +692,10 @@ static void save_checkpoint_file(const char * filename, const char * fn_vocab_mo
} }
struct train_params { struct train_params {
struct train_params_common common;
const char * fn_vocab_model; const char * fn_vocab_model;
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * fn_model_out; const char * fn_model_out;
const char * pattern_fn_it;
const char * fn_latest;
int save_every;
uint32_t seed;
int n_ctx; int n_ctx;
int n_embd; int n_embd;
@ -710,10 +703,7 @@ struct train_params {
int n_layer; int n_layer;
int n_ff; int n_ff;
int n_threads;
int n_examples; int n_examples;
int n_batch;
int n_gradient_accumulation;
float f_norm_rms_eps; float f_norm_rms_eps;
float rope_freq_base; float rope_freq_base;
@ -721,40 +711,8 @@ struct train_params {
int print_info_interval; int print_info_interval;
bool use_flash;
bool use_checkpointing;
bool use_alloc; bool use_alloc;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
int mem_model_gb; int mem_model_gb;
int mem_compute_gb; int mem_compute_gb;
int mem_compute0_gb; int mem_compute0_gb;
@ -762,17 +720,9 @@ struct train_params {
struct train_params get_default_train_params() { struct train_params get_default_train_params() {
struct train_params params; struct train_params params;
params.common = get_default_train_params_common();
params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin"; params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.bin";
params.fn_checkpoint_out = "checkpoint.bin";
params.fn_model_out = "ggml-checkpoint-f32.bin"; params.fn_model_out = "ggml-checkpoint-f32.bin";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128; params.n_ctx = 128;
params.n_embd = 256; params.n_embd = 256;
@ -780,10 +730,7 @@ struct train_params get_default_train_params() {
params.n_layer = 16; params.n_layer = 16;
params.n_ff = 768; params.n_ff = 768;
params.n_threads = 6;
params.n_examples = 1; params.n_examples = 1;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.f_norm_rms_eps = 1e-5f; params.f_norm_rms_eps = 1e-5f;
params.rope_freq_base = 10000.0f; params.rope_freq_base = 10000.0f;
@ -791,60 +738,22 @@ struct train_params get_default_train_params() {
params.print_info_interval = 1; params.print_info_interval = 1;
params.use_flash = true;
params.use_checkpointing = true;
params.use_alloc = true; params.use_alloc = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
params.mem_model_gb = 2; params.mem_model_gb = 2;
params.mem_compute_gb = 24; params.mem_compute_gb = 24;
params.mem_compute0_gb = 8; params.mem_compute0_gb = 8;
return params; return params;
} }
static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) { static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model); fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out); fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd); fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff); fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head); fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
@ -852,49 +761,15 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples); fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval); fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --no-alloc Don't use allocator\n"); fprintf(stderr, " --no-alloc Don't use allocator\n");
fprintf(stderr, " --use-alloc Use allocator (default)\n"); fprintf(stderr, " --use-alloc Use allocator (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater than zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with fewer n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb); fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb); fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb); fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
fprintf(stderr, "\n");
print_common_train_usage(argc, argv, &params->common);
} }
static bool train_params_parse(int argc, char ** argv, struct train_params * params) { static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@@ -909,66 +784,25 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
std::replace(arg.begin(), arg.end(), '_', '-'); std::replace(arg.begin(), arg.end(), '_', '-');
} }
if (arg == "--vocab-model") { if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
if (invalid_param) {
break;
} else if (params->common.print_usage) {
train_print_usage(argc, argv, &default_params);
exit(0);
}
} else if (arg == "--vocab-model") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params->fn_vocab_model = argv[i]; params->fn_vocab_model = argv[i];
} else if (arg == "--train-data") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--model-out") { } else if (arg == "--model-out") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params->fn_model_out = argv[i]; params->fn_model_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_ctx = std::stoi(argv[i]);
} else if (arg == "--embd") { } else if (arg == "--embd") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@@ -1011,24 +845,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break; break;
} }
params->rope_freq_scale = std::stof(argv[i]); params->rope_freq_scale = std::stof(argv[i]);
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "-n" || arg == "--examples") { } else if (arg == "-n" || arg == "--examples") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@@ -1041,142 +857,10 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break; break;
} }
params->print_info_interval = std::stoi(argv[i]); params->print_info_interval = std::stoi(argv[i]);
} else if (arg == "--sample-start") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--no-alloc") { } else if (arg == "--no-alloc") {
params->use_alloc = false; params->use_alloc = false;
} else if (arg == "--use-alloc") { } else if (arg == "--use-alloc") {
params->use_alloc = true; params->use_alloc = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_steps = std::stof(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "--mem-model") { } else if (arg == "--mem-model") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@@ -1195,9 +879,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break; break;
} }
params->mem_compute0_gb = std::stoi(argv[i]); params->mem_compute0_gb = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params); train_print_usage(argc, argv, &default_params);
@@ -1209,9 +890,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
train_print_usage(argc, argv, &default_params); train_print_usage(argc, argv, &default_params);
exit(1); exit(1);
} }
if (params->escape) { finish_processing_train_args(&params->common);
process_escapes(params->sample_start);
}
return true; return true;
} }
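For readers following the refactor, the new parsing flow can be condensed into a short sketch. It is an illustrative reconstruction based only on the calls visible in this diff (consume_common_train_arg, finish_processing_train_args, train_print_usage, the params->common member); the exact struct declaration is not part of this excerpt, so treat the layout and the *_sketch names as assumptions rather than the committed code:

// Example-specific parameters embed the shared block that moved to common/train.
struct train_params_sketch {
    struct train_params_common common;   // shared options: data/checkpoint paths, ctx, batch, Adam settings, ...
    const char * fn_vocab_model;         // options that stay specific to this example
    const char * fn_model_out;
    // ... plus example-specific model hyperparameters and memory sizes
};

static bool parse_args_sketch(int argc, char ** argv, struct train_params_sketch * params,
                              const struct train_params_sketch * defaults) {
    bool invalid_param = false;
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        // 1) shared flags are consumed by common/train; it advances i and flags malformed values
        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
            if (invalid_param) {
                break;
            } else if (params->common.print_usage) {
                exit(0);                         // usage printing handled by the example + common helper
            }
        // 2) anything not recognized as a common flag falls through to the example-specific handlers
        } else if (arg == "--model-out") {
            if (++i >= argc) { invalid_param = true; break; }
            params->fn_model_out = argv[i];
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return false;
        }
    }
    if (invalid_param) {
        return false;
    }
    // 3) post-processing of shared options (e.g. applying --escape to --sample-start) also moved to common/train
    finish_processing_train_args(&params->common);
    return true;
}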
@@ -1241,32 +920,32 @@ static void save_train_files(void * vdata, struct train_state * train) {
} }
struct opt_callback_data { struct opt_callback_data {
struct train_params * params; struct train_params_common * params;
struct train_state * train; struct train_state * train;
save_train_files_callback save_cb; save_train_files_callback save_cb;
void * save_data; void * save_data;
struct llama_context * lctx; struct llama_context * lctx;
int last_save_iter; int last_save_iter;
llama_token * tokens_data; llama_token * tokens_data;
size_t tokens_size; size_t tokens_size;
size_t * samples_begin; size_t * samples_begin;
size_t * samples_size; size_t * samples_size;
size_t * shuffled_samples_begin; size_t * shuffled_samples_begin;
size_t * shuffled_samples_size; size_t * shuffled_samples_size;
size_t samples_count; size_t samples_count;
struct ggml_tensor * tokens_input; struct ggml_tensor * tokens_input;
struct ggml_tensor * target_logits; struct ggml_tensor * target_logits;
struct ggml_tensor * target_probs; struct ggml_tensor * target_probs;
int first_iter; int first_iter;
int64_t last_time; int64_t last_time;
double millis_per_iter; double millis_per_iter;
}; };
static void opt_callback(void * vdata, int accum_step, float * sched) { static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata; struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params; struct train_params_common * params = data->params;
struct train_state * train = data->train; struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt; struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch; int n_batch = params->n_batch;
int n_ctx = params->n_ctx; int n_ctx = params->n_ctx;
@@ -1385,11 +1064,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
if (params.seed == LLAMA_DEFAULT_SEED) { if (params.common.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL); params.common.seed = time(NULL);
} }
printf("%s: seed: %u\n", __func__, params.seed); printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.seed); srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params(); struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true; llama_params.vocab_only = true;
@@ -1399,7 +1078,7 @@ int main(int argc, char ** argv) {
struct my_llama_model model; struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx); model.hparams.n_vocab = llama_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx; model.hparams.n_ctx = params.common.n_ctx;
model.hparams.n_embd = params.n_embd; model.hparams.n_embd = params.n_embd;
model.hparams.n_head = params.n_head; model.hparams.n_head = params.n_head;
model.hparams.n_layer = params.n_layer; model.hparams.n_layer = params.n_layer;
@@ -1421,34 +1100,34 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx; int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
int n_batch = params.n_batch; int n_batch = params.common.n_batch;
struct train_state * train = init_train_state(params.seed); struct train_state * train = init_train_state(params.common.seed);
struct ggml_opt_context * opt = train->opt; struct ggml_opt_context * opt = train->opt;
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
opt_params_adam.print_forward_graph = false; opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false; opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = params.n_threads; opt_params_adam.n_threads = params.common.n_threads;
opt_params_adam.past = params.opt_past; opt_params_adam.past = params.common.opt_past;
opt_params_adam.delta = params.opt_delta; opt_params_adam.delta = params.common.opt_delta;
opt_params_adam.max_no_improvement = params.opt_max_no_improvement; opt_params_adam.max_no_improvement = params.common.opt_max_no_improvement;
opt_params_adam.n_gradient_accumulation = params.n_gradient_accumulation; opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
opt_params_adam.adam.n_iter = params.adam_n_iter; opt_params_adam.adam.n_iter = params.common.adam_n_iter;
opt_params_adam.adam.sched = 1.0f; opt_params_adam.adam.sched = 1.0f;
opt_params_adam.adam.alpha = params.adam_alpha; opt_params_adam.adam.alpha = params.common.adam_alpha;
opt_params_adam.adam.decay = params.adam_decay; opt_params_adam.adam.decay = params.common.adam_decay;
opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim; opt_params_adam.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
opt_params_adam.adam.beta1 = params.adam_beta1; opt_params_adam.adam.beta1 = params.common.adam_beta1;
opt_params_adam.adam.beta2 = params.adam_beta2; opt_params_adam.adam.beta2 = params.common.adam_beta2;
opt_params_adam.adam.gclip = params.adam_gclip; opt_params_adam.adam.gclip = params.common.adam_gclip;
opt_params_adam.adam.eps_f = params.adam_eps_f; opt_params_adam.adam.eps_f = params.common.adam_eps_f;
opt->ctx = model.ctx; opt->ctx = model.ctx;
opt->params = opt_params_adam; opt->params = opt_params_adam;
printf("%s: init model\n", __func__); printf("%s: init model\n", __func__);
bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train); bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
if (!existed) { if (!existed) {
init_model(&model); init_model(&model);
} }
@@ -1461,7 +1140,7 @@ int main(int argc, char ** argv) {
bool from_scratch = !existed; bool from_scratch = !existed;
if (from_scratch) { if (from_scratch) {
randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f); randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
} }
printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx)); printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
@@ -1485,10 +1164,10 @@ int main(int argc, char ** argv) {
std::vector<size_t> train_samples_size; std::vector<size_t> train_samples_size;
printf("%s: tokenize training data\n", __func__); printf("%s: tokenize training data\n", __func__);
tokenize_file(lctx, tokenize_file(lctx,
params.fn_train_data, params.common.fn_train_data,
params.sample_start, params.common.sample_start,
params.include_sample_start, params.common.include_sample_start,
params.overlapping_samples, params.common.overlapping_samples,
n_tokens, n_tokens,
train_tokens, train_tokens,
train_samples_begin, train_samples_begin,
@@ -1497,16 +1176,16 @@ int main(int argc, char ** argv) {
printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size()); printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size()); const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) { if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
} }
if (params.force_reshuffle) { if (params.common.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
} }
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed); train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
train->shuffle_sample_count = train_samples_size.size(); train->shuffle_sample_count = train_samples_size.size();
train->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
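The shuffle state stored in the train_state is a string serialization of the RNG engine. As a rough illustration of what mt19937_seed_to_state is assumed to do (the actual implementation lives in common/train and may differ), standard <random> engines can be serialized through a stream; the _sketch name below is hypothetical:

#include <random>
#include <sstream>
#include <string>

// Hypothetical sketch: turn a seed into a serialized std::mt19937 state string.
static std::string mt19937_seed_to_state_sketch(unsigned seed) {
    std::mt19937 rng(seed);   // freshly seeded engine
    std::stringstream s;
    s << rng;                 // standard engines support textual (de)serialization of their state
    return s.str();
}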
@@ -1525,15 +1204,15 @@ int main(int argc, char ** argv) {
printf("%s: begin training\n", __func__); printf("%s: begin training\n", __func__);
save_train_files_data save_data; save_train_files_data save_data;
save_data.fn_checkpoint_out = params.fn_checkpoint_out; save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
save_data.fn_model_out = params.fn_model_out; save_data.fn_model_out = params.fn_model_out;
save_data.fn_vocab_model = params.fn_vocab_model; save_data.fn_vocab_model = params.fn_vocab_model;
save_data.pattern_fn_it = params.pattern_fn_it; save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.fn_latest; save_data.fn_latest = params.common.fn_latest;
save_data.model = &model; save_data.model = &model;
struct opt_callback_data opt_cb_data; struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params; opt_cb_data.params = &params.common;
opt_cb_data.train = train; opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files; opt_cb_data.save_cb = &save_train_files;
opt_cb_data.save_data = &save_data; opt_cb_data.save_data = &save_data;
@@ -1587,7 +1266,7 @@ int main(int argc, char ** argv) {
struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_cgraph * gb = ggml_new_graph(ctx0); struct ggml_cgraph * gb = ggml_new_graph(ctx0);
struct ggml_cgraph * gb_tmp = params.use_checkpointing struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx0) ? ggml_new_graph(ctx0)
: NULL; : NULL;
@@ -1601,21 +1280,21 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp, gf, gb, gb_tmp,
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.use_flash, params.common.use_flash,
params.use_checkpointing params.common.use_checkpointing
); );
size_t used_mem_before_opt = ggml_used_mem(ctx0); size_t used_mem_before_opt = ggml_used_mem(ctx0);
opt->params.adam.sched = learning_schedule( opt->params.adam.sched = learning_schedule(
opt->iter, opt->iter,
params.warmup, params.common.warmup,
params.cos_decay_steps, params.common.cos_decay_steps,
params.adam_alpha, params.common.adam_alpha,
params.adam_min_alpha, params.common.adam_min_alpha,
params.cos_decay_min, params.common.cos_decay_min,
params.cos_decay_restart, params.common.cos_decay_restart,
params.enable_restart); params.common.enable_restart);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched); printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
@@ -1623,7 +1302,7 @@ int main(int argc, char ** argv) {
size_t used_mem_after_opt = ggml_used_mem(ctx0); size_t used_mem_after_opt = ggml_used_mem(ctx0);
int n_iter = params.adam_n_iter; int n_iter = params.common.adam_n_iter;
train->train_its = opt->iter; train->train_its = opt->iter;
train->train_samples += n_batch * n_iter; train->train_samples += n_batch * n_iter;
train->train_tokens += n_batch * n_tokens * n_iter; train->train_tokens += n_batch * n_tokens * n_iter;