move common train params into common/train
This commit is contained in:
parent ee27333b16
commit e9758ae1d2
4 changed files with 552 additions and 821 deletions
323  common/train.cpp
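The hunks below introduce a shared `struct train_params_common` in common/train, together with `get_default_train_params_common`, `print_common_train_usage`, `consume_common_train_arg`, and `finish_processing_train_args`, and then rewrite the training examples to delegate to them. As orientation before reading the diff, here is a minimal sketch of how a caller can drive the new helpers; only the four `train_params_common` functions come from this commit, while the `main` wrapper, the error messages, and the hypothetical `--my-extra-flag` option are illustrative assumptions:

```cpp
#include <cstdio>
#include <cstring>
#include "train.h"   // the common/train.h header added in this commit

int main(int argc, char ** argv) {
    struct train_params_common params = get_default_train_params_common();
    bool invalid_param = false;

    for (int i = 1; i < argc; i++) {
        // let the shared parser consume every option it recognizes
        if (consume_common_train_arg(argc, argv, &i, &params, &invalid_param)) {
            if (invalid_param) {
                fprintf(stderr, "error: invalid or missing parameter value\n");
                return 1;
            }
            if (params.print_usage) {       // set by -h / --help
                print_common_train_usage(argc, argv, &params);
                return 0;
            }
        } else if (strcmp(argv[i], "--my-extra-flag") == 0) {
            // example-specific options would be handled here (hypothetical)
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", argv[i]);
            return 1;
        }
    }
    finish_processing_train_args(&params);  // applies --escape to sample_start
    return 0;
}
```

This mirrors the change to `train_params_parse` in the finetune example further down: each example now keeps only its own options and forwards everything else to `consume_common_train_arg`.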
@@ -1006,3 +1006,326 @@ std::string get_train_filename(const char * filename, const char * pattern_it, c
     std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
     return replace_str(filename, pattern_it, sit.c_str());
 }
+
+struct train_params_common get_default_train_params_common() {
+    struct train_params_common params;
+    params.fn_train_data     = "shakespeare.txt";
+    params.fn_checkpoint_in  = "checkpoint.gguf";
+    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
+    params.pattern_fn_it     = "ITERATION";
+    params.fn_latest         = "LATEST";
+
+    params.print_usage = false;
+
+    params.save_every = 10;
+
+    params.seed = -1;
+
+    params.n_ctx                   = 128;
+    params.n_threads               = 6;
+    params.n_batch                 = 8;
+    params.n_gradient_accumulation = 1;
+
+    params.custom_n_ctx = false;
+
+    params.use_flash         = true;
+    params.use_checkpointing = true;
+
+    params.sample_start           = "";
+    params.include_sample_start   = false;
+    params.escape                 = false;
+    params.overlapping_samples    = false;
+    params.fill_with_next_samples = false;
+    params.separate_with_eos      = false;
+    params.separate_with_bos      = true;
+    params.force_reshuffle        = false;
+
+    params.opt_past               = 0;
+    params.opt_delta              = 1e-5f;
+    params.opt_max_no_improvement = 0;
+
+    params.warmup            = 100;
+    params.cos_decay_steps   = 1000;
+    params.cos_decay_restart = 1.1f;
+    params.cos_decay_min     = 0.1f;
+    params.enable_restart    = false;
+
+    params.adam_n_iter         = 256;
+    params.adam_alpha          = 1e-3f;
+    params.adam_min_alpha      = 0;
+    params.adam_decay          = 1e-1f;
+    params.adam_decay_min_ndim = 2;
+    params.adam_beta1          = 0.9f;
+    params.adam_beta2          = 0.999f;
+    params.adam_gclip          = 1.0f;
+    params.adam_eps_f          = 0.0f;
+    return params;
+}
+
+void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) {
+    // fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    // fprintf(stderr, "\n");
+    // fprintf(stderr, "options:\n");
+    // fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
+    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+    fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
+    fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
+    fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
+    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
+    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
+    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
+    fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
+    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
+    fprintf(stderr, "  --escape                   process sample start escape sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+    fprintf(stderr, "  --overlapping-samples      Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
+    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
+    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention \n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
+    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
+    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
+    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
+    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, "\n");
+}
+
+bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param) {
+    int& i = *idx;
+    std::string arg = argv[i];
+    if (arg == "--train-data") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_train_data = argv[i];
+    } else if (arg == "--checkpoint-in") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_in = argv[i];
+    } else if (arg == "--checkpoint-out") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_out = argv[i];
+    } else if (arg == "--pattern-fn-it") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->pattern_fn_it = argv[i];
+    } else if (arg == "--fn-latest") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_latest = argv[i];
+    } else if (arg == "--save-every") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->save_every = std::stoi(argv[i]);
+    } else if (arg == "-s" || arg == "--seed") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->seed = std::stoi(argv[i]);
+    } else if (arg == "-c" || arg == "--ctx") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_ctx = std::stoi(argv[i]);
+        params->custom_n_ctx = true;
+    } else if (arg == "-t" || arg == "--threads") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_threads = std::stoi(argv[i]);
+    } else if (arg == "-b" || arg == "--batch") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_batch = std::stoi(argv[i]);
+    } else if (arg == "--grad-acc") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
+    } else if (arg == "--sample-start") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->sample_start = std::string(argv[i]);
+    } else if (arg == "--escape") {
+        params->escape = true;
+    } else if (arg == "--include-sample-start") {
+        params->include_sample_start = true;
+    } else if (arg == "--overlapping-samples") {
+        params->overlapping_samples = true;
+    } else if (arg == "--fill-with-next-samples") {
+        params->fill_with_next_samples = true;
+    } else if (arg == "--separate-with-eos") {
+        params->separate_with_eos = true;
+    } else if (arg == "--separate-with-bos") {
+        params->separate_with_bos = true;
+    } else if (arg == "--no-separate-with-eos") {
+        params->separate_with_eos = false;
+    } else if (arg == "--no-separate-with-bos") {
+        params->separate_with_bos = false;
+    } else if (arg == "--force-reshuffle") {
+        params->force_reshuffle = true;
+    } else if (arg == "--no-flash") {
+        params->use_flash = false;
+    } else if (arg == "--use-flash") {
+        params->use_flash = true;
+    } else if (arg == "--no-checkpointing") {
+        params->use_checkpointing = false;
+    } else if (arg == "--use-checkpointing") {
+        params->use_checkpointing = true;
+    } else if (arg == "--warmup") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->warmup = std::stoi(argv[i]);
+    } else if (arg == "--cos-decay-steps") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_steps = std::stoi(argv[i]);
+    } else if (arg == "--cos-decay-restart") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_restart = std::stof(argv[i]);
+    } else if (arg == "--cos-decay-min") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_min = std::stof(argv[i]);
+    } else if (arg == "--enable-restart") {
+        params->enable_restart = true;
+    } else if (arg == "--disable-restart") {
+        params->enable_restart = false;
+    } else if (arg == "--opt-past") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_past = std::stoi(argv[i]);
+    } else if (arg == "--opt-delta") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_delta = std::stof(argv[i]);
+    } else if (arg == "--opt-max-no-improvement") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_max_no_improvement = std::stoi(argv[i]);
+    } else if (arg == "--adam-epsf") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_eps_f = std::stof(argv[i]);
+    } else if (arg == "--adam-iter") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_n_iter = std::stoi(argv[i]);
+    } else if (arg == "--adam-alpha") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_alpha = std::stof(argv[i]);
+    } else if (arg == "--adam-min-alpha") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_min_alpha = std::stof(argv[i]);
+    } else if (arg == "--adam-decay") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_decay = std::stof(argv[i]);
+    } else if (arg == "--adam-decay-min-ndim") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_decay_min_ndim = std::stoi(argv[i]);
+    } else if (arg == "--adam-beta1") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_beta1 = std::stof(argv[i]);
+    } else if (arg == "--adam-beta2") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_beta2 = std::stof(argv[i]);
+    } else if (arg == "--adam-gclip") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_gclip = std::stof(argv[i]);
+    } else if (arg == "-h" || arg == "--help") {
+        params->print_usage = true;
+        return true;
+    } else {
+        return false;
+    }
+    return true;
+}
+
+void finish_processing_train_args(struct train_params_common * params) {
+    if (params->escape) {
+        process_escapes(params->sample_start);
+    }
+}
common/train.h
@@ -26,9 +26,69 @@ struct train_state {
     size_t shuffle_next_sample;
 };
 
+struct train_params_common {
+    const char * fn_train_data;
+    const char * fn_checkpoint_in;
+    const char * fn_checkpoint_out;
+    const char * pattern_fn_it;
+    const char * fn_latest;
+
+    bool print_usage;
+
+    int save_every;
+
+    uint32_t seed;
+
+    int n_ctx;
+    int n_threads;
+    int n_batch;
+    int n_gradient_accumulation;
+
+    bool custom_n_ctx;
+
+    bool use_flash;
+    bool use_checkpointing;
+
+    std::string sample_start;
+    bool include_sample_start;
+    bool escape;
+    bool overlapping_samples;
+    bool fill_with_next_samples;
+    bool separate_with_eos;
+    bool separate_with_bos;
+
+    bool force_reshuffle;
+
+    int   warmup;
+    int   cos_decay_steps;
+    float cos_decay_restart;
+    float cos_decay_min;
+    bool  enable_restart;
+
+    int   opt_past;
+    float opt_delta;
+    int   opt_max_no_improvement;
+
+    int   adam_n_iter;
+    float adam_alpha;
+    float adam_min_alpha;
+    float adam_decay;
+    int   adam_decay_min_ndim;
+    float adam_beta1;
+    float adam_beta2;
+    float adam_gclip;
+    float adam_eps_f;
+};
+
 struct train_state * init_train_state(int seed);
 void free_train_state(struct train_state * state);
 
+struct train_params_common get_default_train_params_common();
+void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
+
+bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
+void finish_processing_train_args(struct train_params_common * params);
+
 struct random_normal_distribution;
 struct random_uniform_distribution;
 
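Before the example diffs, a small sketch (not part of the commit) of how a tool-specific parameter struct is meant to embed this shared block; the `my_train_params` name is hypothetical, while `fn_model_base` is borrowed from the finetune example below:

```cpp
#include "train.h"

struct my_train_params {
    struct train_params_common common;   // shared training options live here
    const char * fn_model_base;          // example-specific option
};

static struct my_train_params my_default_params() {
    struct my_train_params params;
    params.common        = get_default_train_params_common(); // shared defaults
    params.fn_model_base = "";                                 // own defaults
    return params;
}
```

Shared settings are then read through the embedded member, e.g. `params.common.n_ctx`, which is exactly the substitution applied throughout the remaining hunks.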
examples/finetune/finetune.cpp
@@ -1197,24 +1197,10 @@ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lor
 }
 
 struct train_params {
+    struct train_params_common common;
+
     const char * fn_model_base;
-    const char * fn_train_data;
-    const char * fn_checkpoint_in;
-    const char * fn_checkpoint_out;
     const char * fn_lora_out;
-    const char * pattern_fn_it;
-    const char * fn_latest;
-
-    int save_every;
-
-    uint32_t seed;
-
-    int n_ctx;
-    int n_threads;
-    int n_batch;
-    int n_gradient_accumulation;
-
-    bool custom_n_ctx;
 
     bool only_write_lora;
 
@@ -1255,61 +1241,13 @@ struct train_params {
     bool custom_n_rank_tok_embeddings;
     bool custom_n_rank_norm;
     bool custom_n_rank_output;
-
-    bool use_flash;
-    bool use_checkpointing;
-
-    std::string sample_start;
-    bool include_sample_start;
-    bool escape;
-    bool overlapping_samples;
-    bool fill_with_next_samples;
-    bool separate_with_eos;
-    bool separate_with_bos;
-
-    bool force_reshuffle;
-
-    int warmup;
-    int cos_decay_steps;
-    float cos_decay_restart;
-    float cos_decay_min;
-    bool enable_restart;
-
-    int opt_past;
-    float opt_delta;
-    int opt_max_no_improvement;
-
-    int adam_n_iter;
-    float adam_alpha;
-    float adam_min_alpha;
-    float adam_decay;
-    int adam_decay_min_ndim;
-    float adam_beta1;
-    float adam_beta2;
-    float adam_gclip;
-    float adam_eps_f;
 };
 
 static struct train_params get_default_train_params() {
     struct train_params params;
+    params.common = get_default_train_params_common();
     params.fn_model_base = "";
-    params.fn_train_data = "shakespeare.txt";
-    params.fn_checkpoint_in = "checkpoint.gguf";
-    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
     params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
-    params.pattern_fn_it = "ITERATION";
-    params.fn_latest = "LATEST";
-
-    params.save_every = 10;
-
-    params.seed = -1;
-
-    params.n_ctx = 128;
-    params.n_threads = 6;
-    params.n_batch = 8;
-    params.n_gradient_accumulation = 1;
-
-    params.custom_n_ctx = false;
-
     params.only_write_lora = false;
 
@@ -1351,59 +1289,18 @@ static struct train_params get_default_train_params() {
     params.custom_n_rank_norm = false;
     params.custom_n_rank_output = false;
-
-    params.use_flash = true;
-    params.use_checkpointing = true;
-
-    params.sample_start = "";
-    params.include_sample_start = false;
-    params.escape = false;
-    params.overlapping_samples = false;
-    params.fill_with_next_samples = false;
-    params.separate_with_eos = false;
-    params.separate_with_bos = true;
-    params.force_reshuffle = false;
-
-    params.opt_past = 0;
-    params.opt_delta = 1e-5f;
-    params.opt_max_no_improvement = 0;
-
-    params.warmup = 100;
-    params.cos_decay_steps = 1000;
-    params.cos_decay_restart = 1.1f;
-    params.cos_decay_min = 0.1f;
-    params.enable_restart = false;
-
-    params.adam_n_iter = 256;
-    params.adam_alpha = 1e-3f;
-    params.adam_min_alpha = 0;
-    params.adam_decay = 1e-1f;
-    params.adam_decay_min_ndim = 2;
-    params.adam_beta1 = 0.9f;
-    params.adam_beta2 = 0.999f;
-    params.adam_gclip = 1.0f;
-    params.adam_eps_f = 0.0f;
     return params;
 }
 
-static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help                 show this help message and exit\n");
 
     fprintf(stderr, "  --model-base FNAME         model path from which to load base model (default '%s')\n", params->fn_model_base);
-    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
-    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
-    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
     fprintf(stderr, "  --lora-out FNAME           path to save llama lora (default '%s')\n", params->fn_lora_out);
-    fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
-    fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
-    fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
     fprintf(stderr, "  --only-write-lora          only save llama lora, don't do any training\n");
-    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
-    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
-    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
-    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
-    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, "  --norm-rms-eps F           RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
     fprintf(stderr, "  --rope-freq-base F         Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, "  --rope-freq-scale F        Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
@@ -1421,39 +1318,8 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
     fprintf(stderr, "  --rank-w1 N                LORA rank for w1 tensor, overrides default rank.\n");
     fprintf(stderr, "  --rank-w2 N                LORA rank for w2 tensor, overrides default rank.\n");
     fprintf(stderr, "  --rank-w3 N                LORA rank for w3 tensor, overrides default rank.\n");
-    fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
-    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
-    fprintf(stderr, "  --escape                   process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --overlapping-samples      Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
-    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
-    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
-    fprintf(stderr, "  --no-flash                 Don't use flash attention \n");
-    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
-    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
-    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
-    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
-    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
-    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
-    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
-    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
-    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
-    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
-    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
-    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
-    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
-    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
-    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
-    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
-    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
-    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
-    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
-    fprintf(stderr, "\n");
+
+    print_common_train_usage(argc, argv, &params->common);
 }
 
 static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@@ -1468,87 +1334,27 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
 
-        if (arg == "--model-base") {
+        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
+            if (invalid_param) {
+                break;
+            } else if (params->common.print_usage) {
+                train_print_usage(argc, argv, &default_params);
+                exit(0);
+            }
+        } else if (arg == "--model-base") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_model_base = argv[i];
-        } else if (arg == "--train-data") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_train_data = argv[i];
-        } else if (arg == "--checkpoint-in") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_in = argv[i];
-        } else if (arg == "--checkpoint-out") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_out = argv[i];
         } else if (arg == "--lora-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_lora_out = argv[i];
-        } else if (arg == "--pattern-fn-it") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->pattern_fn_it = argv[i];
-        } else if (arg == "--fn-latest") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_latest = argv[i];
-        } else if (arg == "--save-every") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->save_every = std::stoi(argv[i]);
         } else if (arg == "--only-write-lora") {
             params->only_write_lora = true;
-        } else if (arg == "-s" || arg == "--seed") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->seed = std::stoi(argv[i]);
-        } else if (arg == "-c" || arg == "--ctx") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_ctx = std::stoi(argv[i]);
-            params->custom_n_ctx = true;
-        } else if (arg == "-t" || arg == "--threads") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_threads = std::stoi(argv[i]);
-        } else if (arg == "-b" || arg == "--batch") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_batch = std::stoi(argv[i]);
-        } else if (arg == "--grad-acc") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
         } else if (arg == "--norm-rms-eps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1667,141 +1473,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
             }
             params->n_rank_w3 = std::stoi(argv[i]);
             params->custom_n_rank_w3 = true;
-        } else if (arg == "--sample-start") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->sample_start = std::string(argv[i]);
-        } else if (arg == "--escape") {
-            params->escape = true;
-        } else if (arg == "--include-sample-start") {
-            params->include_sample_start = true;
-        } else if (arg == "--overlapping-samples") {
-            params->overlapping_samples = true;
-        } else if (arg == "--fill-with-next-samples") {
-            params->fill_with_next_samples = true;
-        } else if (arg == "--separate-with-eos") {
-            params->separate_with_eos = true;
-        } else if (arg == "--separate-with-bos") {
-            params->separate_with_bos = true;
-        } else if (arg == "--no-separate-with-eos") {
-            params->separate_with_eos = false;
-        } else if (arg == "--no-separate-with-bos") {
-            params->separate_with_bos = false;
-        } else if (arg == "--force-reshuffle") {
-            params->force_reshuffle = true;
-        } else if (arg == "--no-flash") {
-            params->use_flash = false;
-        } else if (arg == "--use-flash") {
-            params->use_flash = true;
-        } else if (arg == "--no-checkpointing") {
-            params->use_checkpointing = false;
-        } else if (arg == "--use-checkpointing") {
-            params->use_checkpointing = true;
-        } else if (arg == "--warmup") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->warmup = std::stoi(argv[i]);
-        } else if (arg == "--cos-decay-steps") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_steps = std::stof(argv[i]);
-        } else if (arg == "--cos-decay-restart") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_restart = std::stof(argv[i]);
-        } else if (arg == "--cos-decay-min") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_min = std::stof(argv[i]);
-        } else if (arg == "--enable-restart") {
-            params->enable_restart = true;
-        } else if (arg == "--disable-restart") {
-            params->enable_restart = false;
-        } else if (arg == "--opt-past") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_past = std::stoi(argv[i]);
-        } else if (arg == "--opt-delta") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_delta = std::stof(argv[i]);
-        } else if (arg == "--opt-max-no-improvement") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_max_no_improvement = std::stoi(argv[i]);
-        } else if (arg == "--adam-epsf") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_eps_f = std::stof(argv[i]);
-        } else if (arg == "--adam-iter") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_n_iter = std::stoi(argv[i]);
-        } else if (arg == "--adam-alpha") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_alpha = std::stof(argv[i]);
-        } else if (arg == "--adam-min-alpha") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_min_alpha = std::stof(argv[i]);
-        } else if (arg == "--adam-decay") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_decay = std::stof(argv[i]);
-        } else if (arg == "--adam-decay-min-ndim") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_decay_min_ndim = std::stoi(argv[i]);
-        } else if (arg == "--adam-beta1") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_beta1 = std::stof(argv[i]);
-        } else if (arg == "--adam-beta2") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_beta2 = std::stof(argv[i]);
-        } else if (arg == "--adam-gclip") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_gclip = std::stof(argv[i]);
-        } else if (arg == "-h" || arg == "--help") {
-            train_print_usage(argc, argv, &default_params);
-            exit(0);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             train_print_usage(argc, argv, &default_params);
@@ -1813,9 +1484,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         train_print_usage(argc, argv, &default_params);
         exit(1);
     }
-    if (params->escape) {
-        process_escapes(params->sample_start);
-    }
+    finish_processing_train_args(&params->common);
     return true;
 }
 
@@ -1844,31 +1513,31 @@ static void save_train_files(void * vdata, struct train_state * train) {
 }
 
 struct opt_callback_data {
-    struct train_params * params;
+    struct train_params_common * params;
     struct train_state * train;
     save_train_files_callback save_cb;
     void * save_data;
     struct llama_context * lctx;
     int last_save_iter;
     llama_token * tokens_data;
     size_t tokens_size;
     size_t * samples_begin;
     size_t * samples_size;
     size_t * shuffled_samples_begin;
    size_t * shuffled_samples_size;
     size_t samples_count;
     struct ggml_tensor * tokens_input;
     struct ggml_tensor * target_probs;
     int first_iter;
     int64_t last_time;
     double millis_per_iter;
 };
 
 static void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
-    struct train_params * params = data->params;
+    struct train_params_common * params = data->params;
     struct train_state * train = data->train;
     struct ggml_opt_context * opt = train->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;
 
@@ -2019,11 +1688,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
+    if (params.common.seed == LLAMA_DEFAULT_SEED) {
+        params.common.seed = time(NULL);
     }
-    printf("%s: seed: %u\n", __func__, params.seed);
-    srand(params.seed);
+    printf("%s: seed: %u\n", __func__, params.common.seed);
+    srand(params.common.seed);
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = false;
@@ -2033,11 +1702,11 @@ int main(int argc, char ** argv) {
     struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
 
     struct my_llama_model model;
-    init_model(lmodel, &model, params.n_ctx);
+    init_model(lmodel, &model, params.common.n_ctx);
 
     struct my_llama_lora lora;
 
-    struct train_state * train = init_train_state(params.seed);
+    struct train_state * train = init_train_state(params.common.seed);
     struct ggml_opt_context * opt = train->opt;
 
     load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
@@ -2083,30 +1752,30 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
-    opt->params.n_threads = params.n_threads;
-    opt->params.past = params.opt_past;
-    opt->params.delta = params.opt_delta;
-    opt->params.max_no_improvement = params.opt_max_no_improvement;
-    opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
-    opt->params.adam.n_iter = params.adam_n_iter;
+    opt->params.n_threads = params.common.n_threads;
+    opt->params.past = params.common.opt_past;
+    opt->params.delta = params.common.opt_delta;
+    opt->params.max_no_improvement = params.common.opt_max_no_improvement;
+    opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt->params.adam.n_iter = params.common.adam_n_iter;
     opt->params.adam.sched = 1.0f;
-    opt->params.adam.alpha = params.adam_alpha;
-    opt->params.adam.decay = params.adam_decay;
-    opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
-    opt->params.adam.beta1 = params.adam_beta1;
-    opt->params.adam.beta2 = params.adam_beta2;
-    opt->params.adam.gclip = params.adam_gclip;
-    opt->params.adam.eps_f = params.adam_eps_f;
+    opt->params.adam.alpha = params.common.adam_alpha;
+    opt->params.adam.decay = params.common.adam_decay;
+    opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+    opt->params.adam.beta1 = params.common.adam_beta1;
+    opt->params.adam.beta2 = params.common.adam_beta2;
+    opt->params.adam.gclip = params.common.adam_gclip;
+    opt->params.adam.eps_f = params.common.adam_eps_f;
 
     ggml_allocr * alloc = NULL;
 
     printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train);
+    bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
 
     if (existed) {
         // overwrite last n_ctx with user provided n_ctx
-        if (params.custom_n_ctx) {
-            model.hparams.n_ctx = params.n_ctx;
+        if (params.common.custom_n_ctx) {
+            model.hparams.n_ctx = params.common.n_ctx;
         }
 
         const bool opt_param_count_changed = (
@@ -2124,7 +1793,7 @@ int main(int argc, char ** argv) {
             || (lora.hparams.n_rank_output != n_rank_output)
         );
 
-        const bool opt_past_changed = opt->params.past != params.opt_past;
+        const bool opt_past_changed = opt->params.past != params.common.opt_past;
 
         if (opt_param_count_changed) {
             print_lora_params(&lora.hparams);
@@ -2139,7 +1808,7 @@ int main(int argc, char ** argv) {
         }
     } else { // existed == false
         init_lora(&model, &lora);
-        randomize_lora(&lora, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
         if (!params.only_write_lora) {
             ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
         }
@@ -2159,8 +1828,8 @@ int main(int argc, char ** argv) {
         save_train_files_data save_data;
         save_data.fn_checkpoint_out = "";
         save_data.fn_lora_out = params.fn_lora_out;
-        save_data.pattern_fn_it = params.pattern_fn_it;
-        save_data.fn_latest = params.fn_latest;
+        save_data.pattern_fn_it = params.common.pattern_fn_it;
+        save_data.fn_latest = params.common.fn_latest;
         save_data.model = &model;
         save_data.lora = &lora;
 
@@ -2175,7 +1844,7 @@ int main(int argc, char ** argv) {
 
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
-    int n_batch = params.n_batch;
+    int n_batch = params.common.n_batch;
 
     printf("%s: opt iter %d\n", __func__, opt->iter);
 
@@ -2215,7 +1884,7 @@ int main(int argc, char ** argv) {
     size_t estimated_compute_size_wo_data = (
         ggml_tensor_overhead()*GGML_MAX_NODES*2
         + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
-            params.use_checkpointing ? 3 : 2
+            params.common.use_checkpointing ? 3 : 2
         )
     );
     struct ggml_init_params ctx_compute_params = {
@@ -2242,7 +1911,7 @@ int main(int argc, char ** argv) {
         gf = ggml_new_graph(ctx_compute);
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph(ctx_compute);
-        gb_tmp = params.use_checkpointing
+        gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph(ctx_compute)
            : NULL;
         loss = llama_build_lora_finetune_graphs(
@@ -2250,8 +1919,8 @@ int main(int argc, char ** argv) {
             gf, gb, gb_tmp,
             &logits, tokens_input, target_probs,
             n_tokens, n_batch,
-            params.use_flash,
-            params.use_checkpointing
+            params.common.use_flash,
+            params.common.use_checkpointing
         );
         size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
         if (max_compute_size < best_compute_size) {
@@ -2275,7 +1944,7 @@ int main(int argc, char ** argv) {
     gf = ggml_new_graph(ctx_compute);
     gf->order = best_order;
     gb = ggml_new_graph(ctx_compute);
-    gb_tmp = params.use_checkpointing
+    gb_tmp = params.common.use_checkpointing
         ? ggml_new_graph(ctx_compute)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
@@ -2283,8 +1952,8 @@ int main(int argc, char ** argv) {
         gf, gb, gb_tmp,
         &logits, tokens_input, target_probs,
         n_tokens, n_batch,
-        params.use_flash,
-        params.use_checkpointing
+        params.common.use_flash,
+        params.common.use_checkpointing
     );
     ggml_allocr_free(alloc);
 
@@ -2294,10 +1963,10 @@ int main(int argc, char ** argv) {
     std::vector<size_t> train_samples_size;
     printf("%s: tokenize training data\n", __func__);
     tokenize_file(lctx,
-            params.fn_train_data,
-            params.sample_start,
-            params.include_sample_start,
-            params.overlapping_samples,
+            params.common.fn_train_data,
+            params.common.sample_start,
+            params.common.include_sample_start,
+            params.common.overlapping_samples,
             n_tokens,
             train_tokens,
             train_samples_begin,
@@ -2318,16 +1987,16 @@ int main(int argc, char ** argv) {
     }
     printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
 
-    size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+    size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
     const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
     if (changed_train_data) {
         printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
     }
-    if (params.force_reshuffle) {
+    if (params.common.force_reshuffle) {
         printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
     }
-    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
-        train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
+    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+        train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
         train->shuffle_sample_count = train_samples_size.size();
         train->shuffle_next_sample = 0;
         train->shuffle_samples_hash = shuffle_samples_hash;
@@ -2347,15 +2016,15 @@ int main(int argc, char ** argv) {
    printf("%s: begin training\n", __func__);

    save_train_files_data save_data;
-    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
    save_data.fn_lora_out = params.fn_lora_out;
-    save_data.pattern_fn_it = params.pattern_fn_it;
-    save_data.fn_latest = params.fn_latest;
+    save_data.pattern_fn_it = params.common.pattern_fn_it;
+    save_data.fn_latest = params.common.fn_latest;
    save_data.model = &model;
    save_data.lora = &lora;

    struct opt_callback_data opt_cb_data;
-    opt_cb_data.params = &params;
+    opt_cb_data.params = &params.common;
    opt_cb_data.train = train;
    opt_cb_data.save_cb = &save_train_files;
    opt_cb_data.save_data = &save_data;
@@ -2375,7 +2044,7 @@ int main(int argc, char ** argv) {
    opt_cb_data.millis_per_iter = 0.0;

    // measure required memory for work buffer
-    size_t max_work_size = ggml_graph_plan(gb, params.n_threads).work_size + GGML_OBJECT_SIZE;
+    size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
    printf("%s: max_work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));

    // context for work buffer
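The hunks above are mechanical: every option that moved into the shared struct is now read through the nested common member, while example-specific fields (such as fn_lora_out) stay at the top level. A minimal, self-contained sketch of that layout, using trimmed stand-in structs rather than the real definitions from common/train.h:

    #include <cstdio>

    // Trimmed stand-ins for the real structs in common/train.h; only a few
    // representative fields are shown, and the fn_lora_out value below is
    // purely illustrative.
    struct train_params_common {
        const char * fn_train_data;
        int          n_threads;
        bool         use_flash;
        bool         use_checkpointing;
    };

    struct train_params {
        struct train_params_common common; // shared options live here now
        const char * fn_lora_out;          // example-specific options stay at the top level
    };

    int main() {
        train_params params;
        params.common.fn_train_data     = "shakespeare.txt";
        params.common.n_threads         = 6;
        params.common.use_flash         = true;
        params.common.use_checkpointing = true;
        params.fn_lora_out              = "lora.gguf"; // illustrative value

        // Call sites that used to read params.use_flash now go through params.common.
        std::printf("flash=%d threads=%d data=%s\n",
                    params.common.use_flash ? 1 : 0,
                    params.common.n_threads,
                    params.common.fn_train_data);
        return 0;
    }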
@@ -692,17 +692,10 @@ static void save_checkpoint_file(const char * filename, const char * fn_vocab_model
 }

 struct train_params {
+    struct train_params_common common;
+
     const char * fn_vocab_model;
-    const char * fn_train_data;
-    const char * fn_checkpoint_in;
-    const char * fn_checkpoint_out;
     const char * fn_model_out;
-    const char * pattern_fn_it;
-    const char * fn_latest;
-
-    int save_every;
-
-    uint32_t seed;

     int n_ctx;
     int n_embd;
@@ -710,10 +703,7 @@ struct train_params {
     int n_layer;
     int n_ff;

-    int n_threads;
     int n_examples;
-    int n_batch;
-    int n_gradient_accumulation;

     float f_norm_rms_eps;
     float rope_freq_base;
@@ -721,40 +711,8 @@ struct train_params {

     int print_info_interval;

-    bool use_flash;
-    bool use_checkpointing;
     bool use_alloc;

-    std::string sample_start;
-    bool include_sample_start;
-    bool escape;
-    bool overlapping_samples;
-    bool fill_with_next_samples;
-    bool separate_with_eos;
-    bool separate_with_bos;
-
-    bool force_reshuffle;
-
-    int warmup;
-    int cos_decay_steps;
-    float cos_decay_restart;
-    float cos_decay_min;
-    bool enable_restart;
-
-    int opt_past;
-    float opt_delta;
-    int opt_max_no_improvement;
-
-    int adam_n_iter;
-    float adam_alpha;
-    float adam_min_alpha;
-    float adam_decay;
-    int adam_decay_min_ndim;
-    float adam_beta1;
-    float adam_beta2;
-    float adam_gclip;
-    float adam_eps_f;
-
     int mem_model_gb;
     int mem_compute_gb;
     int mem_compute0_gb;
@@ -762,17 +720,9 @@ struct train_params {

 struct train_params get_default_train_params() {
     struct train_params params;
+    params.common = get_default_train_params_common();
     params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
-    params.fn_train_data = "shakespeare.txt";
-    params.fn_checkpoint_in = "checkpoint.bin";
-    params.fn_checkpoint_out = "checkpoint.bin";
     params.fn_model_out = "ggml-checkpoint-f32.bin";
-    params.pattern_fn_it = "ITERATION";
-    params.fn_latest = "LATEST";
-
-    params.save_every = 10;
-
-    params.seed = -1;

     params.n_ctx = 128;
     params.n_embd = 256;
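Default construction now follows the same split: the example calls the common helper first and only fills in what is not shared. A hypothetical, cut-down sketch of that delegation (the real helpers are get_default_train_params_common() in common/train.cpp and the example's get_default_train_params()); the values mirror defaults visible in this diff:

    #include <cstdio>

    // Cut-down stand-ins, not the real structs.
    struct train_params_common { int n_ctx; int n_batch; unsigned int seed; };
    struct train_params        { train_params_common common; int n_embd; int n_layer; };

    static train_params_common default_common_sketch() {
        train_params_common c;
        c.n_ctx   = 128;
        c.n_batch = 8;
        c.seed    = (unsigned int) -1;   // "-1" means: pick a random seed later
        return c;
    }

    static train_params default_params_sketch() {
        train_params p;
        p.common  = default_common_sketch();   // shared defaults first
        p.n_embd  = 256;                       // then example-specific ones
        p.n_layer = 16;
        return p;
    }

    int main() {
        const train_params p = default_params_sketch();
        std::printf("n_ctx=%d n_batch=%d n_embd=%d n_layer=%d\n",
                    p.common.n_ctx, p.common.n_batch, p.n_embd, p.n_layer);
        return 0;
    }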
@@ -780,10 +730,7 @@ struct train_params get_default_train_params() {
     params.n_layer = 16;
     params.n_ff = 768;

-    params.n_threads = 6;
     params.n_examples = 1;
-    params.n_batch = 8;
-    params.n_gradient_accumulation = 1;

     params.f_norm_rms_eps = 1e-5f;
     params.rope_freq_base = 10000.0f;
@@ -791,60 +738,22 @@ struct train_params get_default_train_params() {

     params.print_info_interval = 1;

-    params.use_flash = true;
-    params.use_checkpointing = true;
     params.use_alloc = true;

-    params.sample_start = "";
-    params.include_sample_start = false;
-    params.escape = false;
-    params.overlapping_samples = false;
-    params.fill_with_next_samples = false;
-    params.separate_with_eos = false;
-    params.separate_with_bos = true;
-    params.force_reshuffle = false;
-
-    params.opt_past = 0;
-    params.opt_delta = 1e-5f;
-    params.opt_max_no_improvement = 0;
-
-    params.warmup = 100;
-    params.cos_decay_steps = 1000;
-    params.cos_decay_restart = 1.1f;
-    params.cos_decay_min = 0.1f;
-    params.enable_restart = false;
-
-    params.adam_n_iter = 256;
-    params.adam_alpha = 1e-3f;
-    params.adam_min_alpha = 0;
-    params.adam_decay = 1e-1f;
-    params.adam_decay_min_ndim = 2;
-    params.adam_beta1 = 0.9f;
-    params.adam_beta2 = 0.999f;
-    params.adam_gclip = 1.0f;
-    params.adam_eps_f = 0.0f;
-
     params.mem_model_gb = 2;
     params.mem_compute_gb = 24;
     params.mem_compute0_gb = 8;
     return params;
 }

-static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help                 show this help message and exit\n");

     fprintf(stderr, "  --vocab-model FNAME        model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
-    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
-    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
-    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
     fprintf(stderr, "  --model-out FNAME          path to save ggml model (default '%s')\n", params->fn_model_out);
-    fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
-    fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
-    fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
-    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
-    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
     fprintf(stderr, "  --embd N                   Embedding size used for new models (default %d)\n", params->n_embd);
     fprintf(stderr, "  --ff N                     Feedforward size used for new models. (default %d)\n", params->n_ff);
     fprintf(stderr, "  --head N                   Number of heads for new models (default %d)\n", params->n_head);
@@ -852,49 +761,15 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, "  --norm-rms-eps F           RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
     fprintf(stderr, "  --rope-freq-base F         Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, "  --rope-freq-scale F        Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
-    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
     fprintf(stderr, "  -n N, --examples N         Number of examples to train (default %d)\n", params->n_examples);
-    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
-    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, "  --print-info-interval N    Print infos during training each N examples (default %d)\n", params->print_info_interval);
-    fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
-    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
-    fprintf(stderr, "  --escape                   process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --overlapping-samples      Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
-    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
-    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
-    fprintf(stderr, "  --no-flash                 Don't use flash attention \n");
-    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
-    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
-    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
     fprintf(stderr, "  --no-alloc                 Don't use allocator\n");
     fprintf(stderr, "  --use-alloc                Use allocator (default)\n");
-    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
-    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
-    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
-    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
-    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
-    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
-    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
-    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
-    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
-    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
-    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
-    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
-    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
-    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
-    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
-    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, "  --mem-model N              Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, "  --mem-compute N            Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, "  --mem-compute0 N           Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
-    fprintf(stderr, "\n");
+
+    print_common_train_usage(argc, argv, &params->common);
 }

 static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
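Usage output is split the same way: the example prints only its own options and then appends the shared block. A small self-contained sketch of that call order, with stand-in structs and a stub in place of the real print_common_train_usage():

    #include <cstdio>

    // Stand-ins so the sketch compiles on its own; the real structs and
    // print_common_train_usage() are provided by common/train.
    struct train_params_common { const char * fn_train_data; };
    struct train_params        { train_params_common common; const char * fn_model_out; };

    static void print_common_train_usage_sketch(int, char **, const train_params_common * c) {
        std::fprintf(stderr, "  --train-data FNAME   path from which to load training data (default '%s')\n",
                     c->fn_train_data);
    }

    // Example-specific usage comes first, then the shared block is appended,
    // mirroring the call added at the end of train_print_usage() above.
    static void train_print_usage_sketch(int argc, char ** argv, const train_params * p) {
        std::fprintf(stderr, "usage: %s [options]\n\n", argv[0]);
        std::fprintf(stderr, "  --model-out FNAME    path to save ggml model (default '%s')\n",
                     p->fn_model_out);
        print_common_train_usage_sketch(argc, argv, &p->common);
    }

    int main(int argc, char ** argv) {
        train_params p { { "shakespeare.txt" }, "ggml-checkpoint-f32.bin" };
        train_print_usage_sketch(argc, argv, &p);
        return 0;
    }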
@@ -909,66 +784,25 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
            std::replace(arg.begin(), arg.end(), '_', '-');
        }

-        if (arg == "--vocab-model") {
+        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
+            if (invalid_param) {
+                break;
+            } else if (params->common.print_usage) {
+                train_print_usage(argc, argv, &default_params);
+                exit(0);
+            }
+        } else if (arg == "--vocab-model") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_vocab_model = argv[i];
-        } else if (arg == "--train-data") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_train_data = argv[i];
-        } else if (arg == "--checkpoint-in") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_in = argv[i];
-        } else if (arg == "--checkpoint-out") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_out = argv[i];
        } else if (arg == "--model-out") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_out = argv[i];
-        } else if (arg == "--pattern-fn-it") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->pattern_fn_it = argv[i];
-        } else if (arg == "--fn-latest") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_latest = argv[i];
-        } else if (arg == "--save-every") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->save_every = std::stoi(argv[i]);
-        } else if (arg == "-s" || arg == "--seed") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->seed = std::stoi(argv[i]);
-        } else if (arg == "-c" || arg == "--ctx") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_ctx = std::stoi(argv[i]);
        } else if (arg == "--embd") {
            if (++i >= argc) {
                invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->rope_freq_scale = std::stof(argv[i]);
|
params->rope_freq_scale = std::stof(argv[i]);
|
||||||
} else if (arg == "-t" || arg == "--threads") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->n_threads = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "-b" || arg == "--batch") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->n_batch = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--grad-acc") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
|
|
||||||
} else if (arg == "-n" || arg == "--examples") {
|
} else if (arg == "-n" || arg == "--examples") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -1041,142 +857,10 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->print_info_interval = std::stoi(argv[i]);
|
params->print_info_interval = std::stoi(argv[i]);
|
||||||
} else if (arg == "--sample-start") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->sample_start = std::string(argv[i]);
|
|
||||||
} else if (arg == "--escape") {
|
|
||||||
params->escape = true;
|
|
||||||
} else if (arg == "--include-sample-start") {
|
|
||||||
params->include_sample_start = true;
|
|
||||||
} else if (arg == "--overlapping-samples") {
|
|
||||||
params->overlapping_samples = true;
|
|
||||||
} else if (arg == "--fill-with-next-samples") {
|
|
||||||
params->fill_with_next_samples = true;
|
|
||||||
} else if (arg == "--separate-with-eos") {
|
|
||||||
params->separate_with_eos = true;
|
|
||||||
} else if (arg == "--separate-with-bos") {
|
|
||||||
params->separate_with_bos = true;
|
|
||||||
} else if (arg == "--no-separate-with-eos") {
|
|
||||||
params->separate_with_eos = false;
|
|
||||||
} else if (arg == "--no-separate-with-bos") {
|
|
||||||
params->separate_with_bos = false;
|
|
||||||
} else if (arg == "--force-reshuffle") {
|
|
||||||
params->force_reshuffle = true;
|
|
||||||
} else if (arg == "--no-flash") {
|
|
||||||
params->use_flash = false;
|
|
||||||
} else if (arg == "--use-flash") {
|
|
||||||
params->use_flash = true;
|
|
||||||
} else if (arg == "--no-checkpointing") {
|
|
||||||
params->use_checkpointing = false;
|
|
||||||
} else if (arg == "--use-checkpointing") {
|
|
||||||
params->use_checkpointing = true;
|
|
||||||
} else if (arg == "--no-alloc") {
|
} else if (arg == "--no-alloc") {
|
||||||
params->use_alloc = false;
|
params->use_alloc = false;
|
||||||
} else if (arg == "--use-alloc") {
|
} else if (arg == "--use-alloc") {
|
||||||
params->use_alloc = true;
|
params->use_alloc = true;
|
||||||
} else if (arg == "--warmup") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->warmup = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--cos-decay-steps") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->cos_decay_steps = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--cos-decay-restart") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->cos_decay_restart = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--cos-decay-min") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->cos_decay_min = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--enable-restart") {
|
|
||||||
params->enable_restart = true;
|
|
||||||
} else if (arg == "--disable-restart") {
|
|
||||||
params->enable_restart = false;
|
|
||||||
} else if (arg == "--opt-past") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->opt_past = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--opt-delta") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->opt_delta = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--opt-max-no-improvement") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->opt_max_no_improvement = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--adam-epsf") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_eps_f = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-iter") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_n_iter = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--adam-alpha") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_alpha = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-min-alpha") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_min_alpha = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-decay") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_decay = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-decay-min-ndim") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_decay_min_ndim = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "--adam-beta1") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_beta1 = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-beta2") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_beta2 = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--adam-gclip") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->adam_gclip = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--mem-model") {
|
} else if (arg == "--mem-model") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@@ -1195,9 +879,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                break;
            }
            params->mem_compute0_gb = std::stoi(argv[i]);
-        } else if (arg == "-h" || arg == "--help") {
-            train_print_usage(argc, argv, &default_params);
-            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            train_print_usage(argc, argv, &default_params);
@@ -1209,9 +890,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
        train_print_usage(argc, argv, &default_params);
        exit(1);
    }
-    if (params->escape) {
-        process_escapes(params->sample_start);
-    }
+    finish_processing_train_args(&params->common);

    return true;
 }
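The parsing loop now tries the shared options first and falls back to example-specific ones, which is why the long chain of --sample-start/--warmup/--adam-* branches could be deleted. A compilable sketch of that dispatch order, with a toy consume function standing in for the real consume_common_train_arg() from common/train (only --ctx is handled here, and the field set is trimmed):

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    struct train_params_common { int n_ctx = 128; };
    struct train_params        { train_params_common common; const char * fn_model_out = "model.gguf"; };

    static bool consume_common_train_arg_sketch(int argc, char ** argv, int * i,
                                                train_params_common * c, bool * invalid) {
        std::string arg = argv[*i];
        if (arg == "-c" || arg == "--ctx") {
            if (++(*i) >= argc) { *invalid = true; return true; }
            c->n_ctx = std::atoi(argv[*i]);
            return true;                        // handled as a shared option
        }
        return false;                           // not a common option
    }

    static bool parse_sketch(int argc, char ** argv, train_params * p) {
        bool invalid = false;
        for (int i = 1; i < argc; ++i) {
            std::string arg = argv[i];
            if (consume_common_train_arg_sketch(argc, argv, &i, &p->common, &invalid)) {
                if (invalid) break;             // shared option was malformed
            } else if (arg == "--model-out") {  // example-specific options come after
                if (++i >= argc) { invalid = true; break; }
                p->fn_model_out = argv[i];
            } else {
                std::fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
                return false;
            }
        }
        return !invalid;
    }

    int main(int argc, char ** argv) {
        train_params params;
        if (!parse_sketch(argc, argv, &params)) return 1;
        std::printf("n_ctx=%d model_out=%s\n", params.common.n_ctx, params.fn_model_out);
        return 0;
    }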
@@ -1241,32 +920,32 @@ static void save_train_files(void * vdata, struct train_state * train) {
 }

 struct opt_callback_data {
-    struct train_params * params;
+    struct train_params_common * params;
     struct train_state * train;
     save_train_files_callback save_cb;
     void * save_data;
     struct llama_context * lctx;
     int last_save_iter;
     llama_token * tokens_data;
     size_t tokens_size;
     size_t * samples_begin;
     size_t * samples_size;
     size_t * shuffled_samples_begin;
     size_t * shuffled_samples_size;
     size_t samples_count;
     struct ggml_tensor * tokens_input;
     struct ggml_tensor * target_logits;
     struct ggml_tensor * target_probs;
     int first_iter;
     int64_t last_time;
     double millis_per_iter;
 };

 static void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
-    struct train_params * params = data->params;
+    struct train_params_common * params = data->params;
     struct train_state * train = data->train;
     struct ggml_opt_context * opt = train->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;

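Because opt_callback_data now stores a pointer to train_params_common, the callback only depends on shared fields, so the same shape can serve both training examples. A trimmed, self-contained sketch of that wiring (field names reduced to the two the sketch uses):

    #include <cstdio>

    struct train_params_common { int n_batch; int n_ctx; };

    struct opt_callback_data_sketch {
        const train_params_common * params;  // shared options only
        int iter;
    };

    static void opt_callback_sketch(void * vdata) {
        auto * data = static_cast<opt_callback_data_sketch *>(vdata);
        const train_params_common * params = data->params;
        std::printf("iter=%d n_batch=%d n_ctx=%d\n", data->iter, params->n_batch, params->n_ctx);
    }

    int main() {
        train_params_common common = { /*n_batch=*/8, /*n_ctx=*/128 };
        opt_callback_data_sketch cb { &common, 0 };
        opt_callback_sketch(&cb);   // stands in for the optimizer invoking the callback each step
        return 0;
    }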
@@ -1385,11 +1064,11 @@ int main(int argc, char ** argv) {
        return 1;
    }

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
+    if (params.common.seed == LLAMA_DEFAULT_SEED) {
+        params.common.seed = time(NULL);
    }
-    printf("%s: seed: %u\n", __func__, params.seed);
-    srand(params.seed);
+    printf("%s: seed: %u\n", __func__, params.common.seed);
+    srand(params.common.seed);

    struct llama_context_params llama_params = llama_context_default_params();
    llama_params.vocab_only = true;
@@ -1399,7 +1078,7 @@ int main(int argc, char ** argv) {

    struct my_llama_model model;
    model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx = params.n_ctx;
+    model.hparams.n_ctx = params.common.n_ctx;
    model.hparams.n_embd = params.n_embd;
    model.hparams.n_head = params.n_head;
    model.hparams.n_layer = params.n_layer;
@@ -1421,34 +1100,34 @@ int main(int argc, char ** argv) {

    int n_tokens = model.hparams.n_ctx;
    int n_vocab = model.hparams.n_vocab;
-    int n_batch = params.n_batch;
+    int n_batch = params.common.n_batch;

-    struct train_state * train = init_train_state(params.seed);
+    struct train_state * train = init_train_state(params.common.seed);
    struct ggml_opt_context * opt = train->opt;

    struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
    opt_params_adam.print_forward_graph = false;
    opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = params.n_threads;
-    opt_params_adam.past = params.opt_past;
-    opt_params_adam.delta = params.opt_delta;
-    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
-    opt_params_adam.n_gradient_accumulation = params.n_gradient_accumulation;
-    opt_params_adam.adam.n_iter = params.adam_n_iter;
+    opt_params_adam.n_threads = params.common.n_threads;
+    opt_params_adam.past = params.common.opt_past;
+    opt_params_adam.delta = params.common.opt_delta;
+    opt_params_adam.max_no_improvement = params.common.opt_max_no_improvement;
+    opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt_params_adam.adam.n_iter = params.common.adam_n_iter;
    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = params.adam_alpha;
-    opt_params_adam.adam.decay = params.adam_decay;
-    opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
-    opt_params_adam.adam.beta1 = params.adam_beta1;
-    opt_params_adam.adam.beta2 = params.adam_beta2;
-    opt_params_adam.adam.gclip = params.adam_gclip;
-    opt_params_adam.adam.eps_f = params.adam_eps_f;
+    opt_params_adam.adam.alpha = params.common.adam_alpha;
+    opt_params_adam.adam.decay = params.common.adam_decay;
+    opt_params_adam.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+    opt_params_adam.adam.beta1 = params.common.adam_beta1;
+    opt_params_adam.adam.beta2 = params.common.adam_beta2;
+    opt_params_adam.adam.gclip = params.common.adam_gclip;
+    opt_params_adam.adam.eps_f = params.common.adam_eps_f;

    opt->ctx = model.ctx;
    opt->params = opt_params_adam;

    printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train);
+    bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
    if (!existed) {
        init_model(&model);
    }
@@ -1461,7 +1140,7 @@ int main(int argc, char ** argv) {

    bool from_scratch = !existed;
    if (from_scratch) {
-        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
    }

    printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
@@ -1485,10 +1164,10 @@ int main(int argc, char ** argv) {
    std::vector<size_t> train_samples_size;
    printf("%s: tokenize training data\n", __func__);
    tokenize_file(lctx,
-            params.fn_train_data,
-            params.sample_start,
-            params.include_sample_start,
-            params.overlapping_samples,
+            params.common.fn_train_data,
+            params.common.sample_start,
+            params.common.include_sample_start,
+            params.common.overlapping_samples,
            n_tokens,
            train_tokens,
            train_samples_begin,
@@ -1497,16 +1176,16 @@ int main(int argc, char ** argv) {

    printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());

-    size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+    size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
    const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
    if (changed_train_data) {
        printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
    }
-    if (params.force_reshuffle) {
+    if (params.common.force_reshuffle) {
        printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
    }
-    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
-        train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
+    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+        train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
        train->shuffle_sample_count = train_samples_size.size();
        train->shuffle_next_sample = 0;
        train->shuffle_samples_hash = shuffle_samples_hash;
@@ -1525,15 +1204,15 @@ int main(int argc, char ** argv) {
    printf("%s: begin training\n", __func__);

    save_train_files_data save_data;
-    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
    save_data.fn_model_out = params.fn_model_out;
    save_data.fn_vocab_model = params.fn_vocab_model;
-    save_data.pattern_fn_it = params.pattern_fn_it;
-    save_data.fn_latest = params.fn_latest;
+    save_data.pattern_fn_it = params.common.pattern_fn_it;
+    save_data.fn_latest = params.common.fn_latest;
    save_data.model = &model;

    struct opt_callback_data opt_cb_data;
-    opt_cb_data.params = &params;
+    opt_cb_data.params = &params.common;
    opt_cb_data.train = train;
    opt_cb_data.save_cb = &save_train_files;
    opt_cb_data.save_data = &save_data;
@@ -1587,7 +1266,7 @@ int main(int argc, char ** argv) {

        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
        struct ggml_cgraph * gb = ggml_new_graph(ctx0);
-        struct ggml_cgraph * gb_tmp = params.use_checkpointing
+        struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
            ? ggml_new_graph(ctx0)
            : NULL;

|
||||||
gf, gb, gb_tmp,
|
gf, gb, gb_tmp,
|
||||||
&logits, tokens_input, target_probs,
|
&logits, tokens_input, target_probs,
|
||||||
n_tokens, n_batch,
|
n_tokens, n_batch,
|
||||||
params.use_flash,
|
params.common.use_flash,
|
||||||
params.use_checkpointing
|
params.common.use_checkpointing
|
||||||
);
|
);
|
||||||
|
|
||||||
size_t used_mem_before_opt = ggml_used_mem(ctx0);
|
size_t used_mem_before_opt = ggml_used_mem(ctx0);
|
||||||
|
|
||||||
opt->params.adam.sched = learning_schedule(
|
opt->params.adam.sched = learning_schedule(
|
||||||
opt->iter,
|
opt->iter,
|
||||||
params.warmup,
|
params.common.warmup,
|
||||||
params.cos_decay_steps,
|
params.common.cos_decay_steps,
|
||||||
params.adam_alpha,
|
params.common.adam_alpha,
|
||||||
params.adam_min_alpha,
|
params.common.adam_min_alpha,
|
||||||
params.cos_decay_min,
|
params.common.cos_decay_min,
|
||||||
params.cos_decay_restart,
|
params.common.cos_decay_restart,
|
||||||
params.enable_restart);
|
params.common.enable_restart);
|
||||||
|
|
||||||
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
|
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
|
||||||
|
|
||||||
|
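learning_schedule() turns the warmup and cosine-decay parameters above into the multiplier stored in opt->params.adam.sched. A rough, self-contained sketch of the warmup-plus-cosine-decay shape it computes; this is an approximation for illustration only (the exact handling of adam_min_alpha and restarts lives in common/train.cpp and differs in detail):

    #include <cmath>
    #include <cstdio>

    // Approximate shape: linear warmup from 0 to 1, then cosine decay from 1
    // down to decay_min over decay_steps iterations.
    static float sched_sketch(int iter, int warmup, int decay_steps, float decay_min) {
        if (iter < warmup) {
            return (float) iter / (float) warmup;                 // linear warmup: 0 -> 1
        }
        float t = (float) (iter - warmup) / (float) decay_steps;  // progress through the decay window
        if (t > 1.0f) t = 1.0f;
        const float cosine = 0.5f * (1.0f + std::cos(3.14159265f * t));
        return decay_min + (1.0f - decay_min) * cosine;           // 1 -> decay_min
    }

    int main() {
        // warmup=100, cos_decay_steps=1000, cos_decay_min=0.1 match the
        // defaults visible in this diff.
        const int iters[] = {0, 50, 100, 600, 1100};
        for (int it : iters) {
            std::printf("iter=%4d sched=%.3f\n", it, sched_sketch(it, 100, 1000, 0.1f));
        }
        return 0;
    }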
@@ -1623,7 +1302,7 @@ int main(int argc, char ** argv) {

        size_t used_mem_after_opt = ggml_used_mem(ctx0);

-        int n_iter = params.adam_n_iter;
+        int n_iter = params.common.adam_n_iter;
        train->train_its = opt->iter;
        train->train_samples += n_batch * n_iter;
        train->train_tokens += n_batch * n_tokens * n_iter;