move common train params into common/train

xaedes 2023-09-16 18:45:59 +02:00
parent ee27333b16
commit e9758ae1d2
4 changed files with 552 additions and 821 deletions


@ -1006,3 +1006,326 @@ std::string get_train_filename(const char * filename, const char * pattern_it, c
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
return replace_str(filename, pattern_it, sit.c_str());
}
struct train_params_common get_default_train_params_common() {
struct train_params_common params;
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.gguf";
params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.print_usage = false;
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128;
params.n_threads = 6;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.custom_n_ctx = false;
params.use_flash = true;
params.use_checkpointing = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
return params;
}
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) {
// fprintf(stderr, "usage: %s [options]\n", argv[0]);
// fprintf(stderr, "\n");
// fprintf(stderr, "options:\n");
// fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, "\n");
}
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param) {
int& i = *idx;
std::string arg = argv[i];
if (arg == "--train-data") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_ctx = std::stoi(argv[i]);
params->custom_n_ctx = true;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "--sample-start") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_steps = std::stoi(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-h" || arg == "--help") {
params->print_usage = true;
return true;
} else {
return false;
}
return true;
}
void finish_processing_train_args(struct train_params_common * params) {
if (params->escape) {
process_escapes(params->sample_start);
}
}


@ -26,9 +26,69 @@ struct train_state {
size_t shuffle_next_sample;
};
struct train_params_common {
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * pattern_fn_it;
const char * fn_latest;
bool print_usage;
int save_every;
uint32_t seed;
int n_ctx;
int n_threads;
int n_batch;
int n_gradient_accumulation;
bool custom_n_ctx;
bool use_flash;
bool use_checkpointing;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
};
struct train_state * init_train_state(int seed);
void free_train_state(struct train_state * state);
struct train_params_common get_default_train_params_common();
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
void finish_processing_train_args(struct train_params_common * params);
struct random_normal_distribution;
struct random_uniform_distribution;


@ -1197,24 +1197,10 @@ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lor
}
struct train_params {
struct train_params_common common;
const char * fn_model_base;
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * fn_lora_out;
const char * pattern_fn_it;
const char * fn_latest;
int save_every;
uint32_t seed;
int n_ctx;
int n_threads;
int n_batch;
int n_gradient_accumulation;
bool custom_n_ctx;
bool only_write_lora;
@ -1255,61 +1241,13 @@ struct train_params {
bool custom_n_rank_tok_embeddings;
bool custom_n_rank_norm;
bool custom_n_rank_output;
bool use_flash;
bool use_checkpointing;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
};
static struct train_params get_default_train_params() {
struct train_params params;
params.common = get_default_train_params_common();
params.fn_model_base = "";
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.gguf";
params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128;
params.n_threads = 6;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.custom_n_ctx = false;
params.only_write_lora = false;
@ -1351,59 +1289,18 @@ static struct train_params get_default_train_params() {
params.custom_n_rank_norm = false;
params.custom_n_rank_output = false;
params.use_flash = true;
params.use_checkpointing = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
return params;
}
static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base);
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " --only-write-lora only save llama lora, don't do any training\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
@ -1421,39 +1318,8 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, "\n");
print_common_train_usage(argc, argv, &params->common);
}
static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@ -1468,87 +1334,27 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg == "--model-base") {
if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
if (invalid_param) {
break;
} else if (params->common.print_usage) {
train_print_usage(argc, argv, &default_params);
exit(0);
}
} else if (arg == "--model-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_model_base = argv[i];
} else if (arg == "--train-data") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--lora-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_lora_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "--only-write-lora") {
params->only_write_lora = true;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_ctx = std::stoi(argv[i]);
params->custom_n_ctx = true;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "--norm-rms-eps") {
if (++i >= argc) {
invalid_param = true;
@ -1667,141 +1473,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
}
params->n_rank_w3 = std::stoi(argv[i]);
params->custom_n_rank_w3 = true;
} else if (arg == "--sample-start") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_steps = std::stof(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params);
@ -1813,9 +1484,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
train_print_usage(argc, argv, &default_params);
exit(1);
}
if (params->escape) {
process_escapes(params->sample_start);
}
finish_processing_train_args(&params->common);
return true;
}
@ -1844,7 +1513,7 @@ static void save_train_files(void * vdata, struct train_state * train) {
}
struct opt_callback_data {
struct train_params * params;
struct train_params_common * params;
struct train_state * train;
save_train_files_callback save_cb;
void * save_data;
@ -1866,7 +1535,7 @@ struct opt_callback_data {
static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params;
struct train_params_common * params = data->params;
struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch;
@ -2019,11 +1688,11 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
if (params.common.seed == LLAMA_DEFAULT_SEED) {
params.common.seed = time(NULL);
}
printf("%s: seed: %u\n", __func__, params.seed);
srand(params.seed);
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = false;
@ -2033,11 +1702,11 @@ int main(int argc, char ** argv) {
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct my_llama_model model;
init_model(lmodel, &model, params.n_ctx);
init_model(lmodel, &model, params.common.n_ctx);
struct my_llama_lora lora;
struct train_state * train = init_train_state(params.seed);
struct train_state * train = init_train_state(params.common.seed);
struct ggml_opt_context * opt = train->opt;
load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
@ -2083,30 +1752,30 @@ int main(int argc, char ** argv) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.n_threads = params.n_threads;
opt->params.past = params.opt_past;
opt->params.delta = params.opt_delta;
opt->params.max_no_improvement = params.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
opt->params.adam.n_iter = params.adam_n_iter;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
opt->params.max_no_improvement = params.common.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
opt->params.adam.n_iter = params.common.adam_n_iter;
opt->params.adam.sched = 1.0f;
opt->params.adam.alpha = params.adam_alpha;
opt->params.adam.decay = params.adam_decay;
opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
opt->params.adam.beta1 = params.adam_beta1;
opt->params.adam.beta2 = params.adam_beta2;
opt->params.adam.gclip = params.adam_gclip;
opt->params.adam.eps_f = params.adam_eps_f;
opt->params.adam.alpha = params.common.adam_alpha;
opt->params.adam.decay = params.common.adam_decay;
opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
opt->params.adam.beta1 = params.common.adam_beta1;
opt->params.adam.beta2 = params.common.adam_beta2;
opt->params.adam.gclip = params.common.adam_gclip;
opt->params.adam.eps_f = params.common.adam_eps_f;
ggml_allocr * alloc = NULL;
printf("%s: init model\n", __func__);
bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train);
bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
if (existed) {
// overwrite last n_ctx with user provided n_ctx
if (params.custom_n_ctx) {
model.hparams.n_ctx = params.n_ctx;
if (params.common.custom_n_ctx) {
model.hparams.n_ctx = params.common.n_ctx;
}
const bool opt_param_count_changed = (
@ -2124,7 +1793,7 @@ int main(int argc, char ** argv) {
|| (lora.hparams.n_rank_output != n_rank_output)
);
const bool opt_past_changed = opt->params.past != params.opt_past;
const bool opt_past_changed = opt->params.past != params.common.opt_past;
if (opt_param_count_changed) {
print_lora_params(&lora.hparams);
@ -2139,7 +1808,7 @@ int main(int argc, char ** argv) {
}
} else { // existed == false
init_lora(&model, &lora);
randomize_lora(&lora, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
if (!params.only_write_lora) {
ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
}
@ -2159,8 +1828,8 @@ int main(int argc, char ** argv) {
save_train_files_data save_data;
save_data.fn_checkpoint_out = "";
save_data.fn_lora_out = params.fn_lora_out;
save_data.pattern_fn_it = params.pattern_fn_it;
save_data.fn_latest = params.fn_latest;
save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.common.fn_latest;
save_data.model = &model;
save_data.lora = &lora;
@ -2175,7 +1844,7 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
int n_batch = params.n_batch;
int n_batch = params.common.n_batch;
printf("%s: opt iter %d\n", __func__, opt->iter);
@ -2215,7 +1884,7 @@ int main(int argc, char ** argv) {
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.use_checkpointing ? 3 : 2
params.common.use_checkpointing ? 3 : 2
)
);
struct ggml_init_params ctx_compute_params = {
@ -2242,7 +1911,7 @@ int main(int argc, char ** argv) {
gf = ggml_new_graph(ctx_compute);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute);
gb_tmp = params.use_checkpointing
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
: NULL;
loss = llama_build_lora_finetune_graphs(
@ -2250,8 +1919,8 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.use_flash,
params.use_checkpointing
params.common.use_flash,
params.common.use_checkpointing
);
size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
if (max_compute_size < best_compute_size) {
@ -2275,7 +1944,7 @@ int main(int argc, char ** argv) {
gf = ggml_new_graph(ctx_compute);
gf->order = best_order;
gb = ggml_new_graph(ctx_compute);
gb_tmp = params.use_checkpointing
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
: NULL;
loss = llama_build_lora_finetune_graphs(
@ -2283,8 +1952,8 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.use_flash,
params.use_checkpointing
params.common.use_flash,
params.common.use_checkpointing
);
ggml_allocr_free(alloc);
@ -2294,10 +1963,10 @@ int main(int argc, char ** argv) {
std::vector<size_t> train_samples_size;
printf("%s: tokenize training data\n", __func__);
tokenize_file(lctx,
params.fn_train_data,
params.sample_start,
params.include_sample_start,
params.overlapping_samples,
params.common.fn_train_data,
params.common.sample_start,
params.common.include_sample_start,
params.common.overlapping_samples,
n_tokens,
train_tokens,
train_samples_begin,
@ -2318,16 +1987,16 @@ int main(int argc, char ** argv) {
}
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
}
if (params.force_reshuffle) {
if (params.common.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
}
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
train->shuffle_sample_count = train_samples_size.size();
train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash;
@ -2347,15 +2016,15 @@ int main(int argc, char ** argv) {
printf("%s: begin training\n", __func__);
save_train_files_data save_data;
save_data.fn_checkpoint_out = params.fn_checkpoint_out;
save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
save_data.fn_lora_out = params.fn_lora_out;
save_data.pattern_fn_it = params.pattern_fn_it;
save_data.fn_latest = params.fn_latest;
save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.common.fn_latest;
save_data.model = &model;
save_data.lora = &lora;
struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params;
opt_cb_data.params = &params.common;
opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files;
opt_cb_data.save_data = &save_data;
@ -2375,7 +2044,7 @@ int main(int argc, char ** argv) {
opt_cb_data.millis_per_iter = 0.0;
// measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.n_threads).work_size + GGML_OBJECT_SIZE;
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
printf("%s: max_work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
// context for work buffer


@ -692,17 +692,10 @@ static void save_checkpoint_file(const char * filename, const char * fn_vocab_mo
}
struct train_params {
struct train_params_common common;
const char * fn_vocab_model;
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * fn_model_out;
const char * pattern_fn_it;
const char * fn_latest;
int save_every;
uint32_t seed;
int n_ctx;
int n_embd;
@ -710,10 +703,7 @@ struct train_params {
int n_layer;
int n_ff;
int n_threads;
int n_examples;
int n_batch;
int n_gradient_accumulation;
float f_norm_rms_eps;
float rope_freq_base;
@ -721,40 +711,8 @@ struct train_params {
int print_info_interval;
bool use_flash;
bool use_checkpointing;
bool use_alloc;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
int mem_model_gb;
int mem_compute_gb;
int mem_compute0_gb;
@ -762,17 +720,9 @@ struct train_params {
struct train_params get_default_train_params() {
struct train_params params;
params.common = get_default_train_params_common();
params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.bin";
params.fn_checkpoint_out = "checkpoint.bin";
params.fn_model_out = "ggml-checkpoint-f32.bin";
params.pattern_fn_it = "ITERATION";
params.fn_latest = "LATEST";
params.save_every = 10;
params.seed = -1;
params.n_ctx = 128;
params.n_embd = 256;
@ -780,10 +730,7 @@ struct train_params get_default_train_params() {
params.n_layer = 16;
params.n_ff = 768;
params.n_threads = 6;
params.n_examples = 1;
params.n_batch = 8;
params.n_gradient_accumulation = 1;
params.f_norm_rms_eps = 1e-5f;
params.rope_freq_base = 10000.0f;
@ -791,60 +738,22 @@ struct train_params get_default_train_params() {
params.print_info_interval = 1;
params.use_flash = true;
params.use_checkpointing = true;
params.use_alloc = true;
params.sample_start = "";
params.include_sample_start = false;
params.escape = false;
params.overlapping_samples = false;
params.fill_with_next_samples = false;
params.separate_with_eos = false;
params.separate_with_bos = true;
params.force_reshuffle = false;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
params.adam_decay = 1e-1f;
params.adam_decay_min_ndim = 2;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
params.mem_model_gb = 2;
params.mem_compute_gb = 24;
params.mem_compute0_gb = 8;
return params;
}
static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
@ -852,49 +761,15 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
fprintf(stderr, " --no-alloc Don't use allocator\n");
fprintf(stderr, " --use-alloc Use allocator (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
fprintf(stderr, "\n");
print_common_train_usage(argc, argv, &params->common);
}
static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@ -909,66 +784,25 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg == "--vocab-model") {
if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
if (invalid_param) {
break;
} else if (params->common.print_usage) {
train_print_usage(argc, argv, &default_params);
exit(0);
}
} else if (arg == "--vocab-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_vocab_model = argv[i];
} else if (arg == "--train-data") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_train_data = argv[i];
} else if (arg == "--checkpoint-in") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_in = argv[i];
} else if (arg == "--checkpoint-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_checkpoint_out = argv[i];
} else if (arg == "--model-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_model_out = argv[i];
} else if (arg == "--pattern-fn-it") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->pattern_fn_it = argv[i];
} else if (arg == "--fn-latest") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_latest = argv[i];
} else if (arg == "--save-every") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->save_every = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->seed = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_ctx = std::stoi(argv[i]);
} else if (arg == "--embd") {
if (++i >= argc) {
invalid_param = true;
@ -1011,24 +845,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break;
}
params->rope_freq_scale = std::stof(argv[i]);
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_threads = std::stoi(argv[i]);
} else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_batch = std::stoi(argv[i]);
} else if (arg == "--grad-acc") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
} else if (arg == "-n" || arg == "--examples") {
if (++i >= argc) {
invalid_param = true;
@ -1041,142 +857,10 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break;
}
params->print_info_interval = std::stoi(argv[i]);
} else if (arg == "--sample-start") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->sample_start = std::string(argv[i]);
} else if (arg == "--escape") {
params->escape = true;
} else if (arg == "--include-sample-start") {
params->include_sample_start = true;
} else if (arg == "--overlapping-samples") {
params->overlapping_samples = true;
} else if (arg == "--fill-with-next-samples") {
params->fill_with_next_samples = true;
} else if (arg == "--separate-with-eos") {
params->separate_with_eos = true;
} else if (arg == "--separate-with-bos") {
params->separate_with_bos = true;
} else if (arg == "--no-separate-with-eos") {
params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--no-alloc") {
params->use_alloc = false;
} else if (arg == "--use-alloc") {
params->use_alloc = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->warmup = std::stoi(argv[i]);
} else if (arg == "--cos-decay-steps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_steps = std::stof(argv[i]);
} else if (arg == "--cos-decay-restart") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_restart = std::stof(argv[i]);
} else if (arg == "--cos-decay-min") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_min = std::stof(argv[i]);
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_n_iter = std::stoi(argv[i]);
} else if (arg == "--adam-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay = std::stof(argv[i]);
} else if (arg == "--adam-decay-min-ndim") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay_min_ndim = std::stoi(argv[i]);
} else if (arg == "--adam-beta1") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta1 = std::stof(argv[i]);
} else if (arg == "--adam-beta2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_beta2 = std::stof(argv[i]);
} else if (arg == "--adam-gclip") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "--mem-model") {
if (++i >= argc) {
invalid_param = true;
@@ -1195,9 +879,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break;
}
params->mem_compute0_gb = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params);
@@ -1209,9 +890,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
train_print_usage(argc, argv, &default_params);
exit(1);
}
if (params->escape) {
process_escapes(params->sample_start);
}
finish_processing_train_args(&params->common);
return true;
}
@@ -1241,7 +920,7 @@ static void save_train_files(void * vdata, struct train_state * train) {
}
struct opt_callback_data {
struct train_params * params;
struct train_params_common * params;
struct train_state * train;
save_train_files_callback save_cb;
void * save_data;
@@ -1264,7 +943,7 @@ struct opt_callback_data {
static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params;
struct train_params_common * params = data->params;
struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch;
@@ -1385,11 +1064,11 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
if (params.common.seed == LLAMA_DEFAULT_SEED) {
params.common.seed = time(NULL);
}
printf("%s: seed: %u\n", __func__, params.seed);
srand(params.seed);
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
@@ -1399,7 +1078,7 @@ int main(int argc, char ** argv) {
struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx;
model.hparams.n_ctx = params.common.n_ctx;
model.hparams.n_embd = params.n_embd;
model.hparams.n_head = params.n_head;
model.hparams.n_layer = params.n_layer;
@@ -1421,34 +1100,34 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
int n_batch = params.n_batch;
int n_batch = params.common.n_batch;
struct train_state * train = init_train_state(params.seed);
struct train_state * train = init_train_state(params.common.seed);
struct ggml_opt_context * opt = train->opt;
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = params.n_threads;
opt_params_adam.past = params.opt_past;
opt_params_adam.delta = params.opt_delta;
opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
opt_params_adam.n_gradient_accumulation = params.n_gradient_accumulation;
opt_params_adam.adam.n_iter = params.adam_n_iter;
opt_params_adam.n_threads = params.common.n_threads;
opt_params_adam.past = params.common.opt_past;
opt_params_adam.delta = params.common.opt_delta;
opt_params_adam.max_no_improvement = params.common.opt_max_no_improvement;
opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
opt_params_adam.adam.n_iter = params.common.adam_n_iter;
opt_params_adam.adam.sched = 1.0f;
opt_params_adam.adam.alpha = params.adam_alpha;
opt_params_adam.adam.decay = params.adam_decay;
opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
opt_params_adam.adam.beta1 = params.adam_beta1;
opt_params_adam.adam.beta2 = params.adam_beta2;
opt_params_adam.adam.gclip = params.adam_gclip;
opt_params_adam.adam.eps_f = params.adam_eps_f;
opt_params_adam.adam.alpha = params.common.adam_alpha;
opt_params_adam.adam.decay = params.common.adam_decay;
opt_params_adam.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
opt_params_adam.adam.beta1 = params.common.adam_beta1;
opt_params_adam.adam.beta2 = params.common.adam_beta2;
opt_params_adam.adam.gclip = params.common.adam_gclip;
opt_params_adam.adam.eps_f = params.common.adam_eps_f;
opt->ctx = model.ctx;
opt->params = opt_params_adam;
printf("%s: init model\n", __func__);
bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train);
bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
if (!existed) {
init_model(&model);
}
@@ -1461,7 +1140,7 @@ int main(int argc, char ** argv) {
bool from_scratch = !existed;
if (from_scratch) {
randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
}
printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
@@ -1485,10 +1164,10 @@ int main(int argc, char ** argv) {
std::vector<size_t> train_samples_size;
printf("%s: tokenize training data\n", __func__);
tokenize_file(lctx,
params.fn_train_data,
params.sample_start,
params.include_sample_start,
params.overlapping_samples,
params.common.fn_train_data,
params.common.sample_start,
params.common.include_sample_start,
params.common.overlapping_samples,
n_tokens,
train_tokens,
train_samples_begin,
@@ -1497,16 +1176,16 @@ int main(int argc, char ** argv) {
printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
}
if (params.force_reshuffle) {
if (params.common.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
}
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
train->shuffle_sample_count = train_samples_size.size();
train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash;
@@ -1525,15 +1204,15 @@ int main(int argc, char ** argv) {
printf("%s: begin training\n", __func__);
save_train_files_data save_data;
save_data.fn_checkpoint_out = params.fn_checkpoint_out;
save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
save_data.fn_model_out = params.fn_model_out;
save_data.fn_vocab_model = params.fn_vocab_model;
save_data.pattern_fn_it = params.pattern_fn_it;
save_data.fn_latest = params.fn_latest;
save_data.pattern_fn_it = params.common.pattern_fn_it;
save_data.fn_latest = params.common.fn_latest;
save_data.model = &model;
struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params;
opt_cb_data.params = &params.common;
opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files;
opt_cb_data.save_data = &save_data;
@@ -1587,7 +1266,7 @@ int main(int argc, char ** argv) {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_cgraph * gb = ggml_new_graph(ctx0);
struct ggml_cgraph * gb_tmp = params.use_checkpointing
struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx0)
: NULL;
@@ -1601,21 +1280,21 @@ int main(int argc, char ** argv) {
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.use_flash,
params.use_checkpointing
params.common.use_flash,
params.common.use_checkpointing
);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
opt->params.adam.sched = learning_schedule(
opt->iter,
params.warmup,
params.cos_decay_steps,
params.adam_alpha,
params.adam_min_alpha,
params.cos_decay_min,
params.cos_decay_restart,
params.enable_restart);
params.common.warmup,
params.common.cos_decay_steps,
params.common.adam_alpha,
params.common.adam_min_alpha,
params.common.cos_decay_min,
params.common.cos_decay_restart,
params.common.enable_restart);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
@@ -1623,7 +1302,7 @@ int main(int argc, char ** argv) {
size_t used_mem_after_opt = ggml_used_mem(ctx0);
int n_iter = params.adam_n_iter;
int n_iter = params.common.adam_n_iter;
train->train_its = opt->iter;
train->train_samples += n_batch * n_iter;
train->train_tokens += n_batch * n_tokens * n_iter;