remove lbfgs related train parameters

This commit is contained in:
xaedes 2023-09-16 15:59:38 +02:00
parent ab56b63b27
commit 00b656f6db
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 22 additions and 81 deletions

View file

@ -2081,7 +2081,6 @@ struct train_params {
bool custom_n_rank_norm;
bool custom_n_rank_output;
bool use_adam;
bool use_flash;
bool use_checkpointing;
@ -2095,7 +2094,6 @@ struct train_params {
bool force_reshuffle;
// only adam
int warmup;
int cos_decay_steps;
float cos_decay_restart;
@ -2106,7 +2104,6 @@ struct train_params {
float opt_delta;
int opt_max_no_improvement;
int lbfgs_n_iter;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
@ -2179,7 +2176,6 @@ static struct train_params get_default_train_params() {
params.custom_n_rank_norm = false;
params.custom_n_rank_output = false;
params.use_adam = true;
params.use_flash = true;
params.use_checkpointing = true;
@ -2196,14 +2192,12 @@ static struct train_params get_default_train_params() {
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
// only adam
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.lbfgs_n_iter = 256;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
@ -2262,7 +2256,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
@ -2285,7 +2278,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, "\n");
}
@ -2524,10 +2516,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--use-lbfgs") {
params->use_adam = false;
} else if (arg == "--use-adam") {
params->use_adam = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
@ -2636,12 +2624,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "--lbfgs-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->lbfgs_n_iter = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
@ -2728,7 +2710,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
double remaining_millis = 0.0;
if (data->millis_per_iter > 0.0) {
const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
const int n_iter = params->adam_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_millis = remaining_iter * data->millis_per_iter;
@ -2943,35 +2925,23 @@ int main(int argc, char ** argv) {
lora.hparams.n_rank_output = n_rank_output;
// set opt params from command line
if (params.use_adam) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.n_threads = params.n_threads;
opt->params.past = params.opt_past;
opt->params.delta = params.opt_delta;
opt->params.max_no_improvement = params.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
opt->params.adam.n_iter = params.adam_n_iter;
opt->params.adam.sched = 1.0f;
opt->params.adam.alpha = params.adam_alpha;
opt->params.adam.decay = params.adam_decay;
opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
opt->params.adam.beta1 = params.adam_beta1;
opt->params.adam.beta2 = params.adam_beta2;
opt->params.adam.gclip = params.adam_gclip;
opt->params.adam.eps_f = params.adam_eps_f;
} else {
opt->params = ggml_opt_default_params(GGML_OPT_LBFGS);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.n_threads = params.n_threads;
opt->params.past = params.opt_past;
opt->params.delta = params.opt_delta;
opt->params.max_no_improvement = params.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
opt->params.lbfgs.n_iter = params.lbfgs_n_iter;
}
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.n_threads = params.n_threads;
opt->params.past = params.opt_past;
opt->params.delta = params.opt_delta;
opt->params.max_no_improvement = params.opt_max_no_improvement;
opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
opt->params.adam.n_iter = params.adam_n_iter;
opt->params.adam.sched = 1.0f;
opt->params.adam.alpha = params.adam_alpha;
opt->params.adam.decay = params.adam_decay;
opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
opt->params.adam.beta1 = params.adam_beta1;
opt->params.adam.beta2 = params.adam_beta2;
opt->params.adam.gclip = params.adam_gclip;
opt->params.adam.eps_f = params.adam_eps_f;
ggml_allocr * alloc = NULL;

View file

@ -1620,7 +1620,6 @@ struct train_params {
int print_info_interval;
bool use_adam;
bool use_flash;
bool use_checkpointing;
bool use_alloc;
@ -1635,7 +1634,6 @@ struct train_params {
bool force_reshuffle;
// only adam
int warmup;
int cos_decay_steps;
float cos_decay_restart;
@ -1646,7 +1644,6 @@ struct train_params {
float opt_delta;
int opt_max_no_improvement;
int lbfgs_n_iter;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
@ -1693,7 +1690,6 @@ struct train_params get_default_train_params() {
params.print_info_interval = 1;
params.use_adam = true;
params.use_flash = true;
params.use_checkpointing = true;
params.use_alloc = true;
@ -1711,14 +1707,12 @@ struct train_params get_default_train_params() {
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
// only adam
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_min = 0.1f;
params.enable_restart = false;
params.lbfgs_n_iter = 256;
params.adam_n_iter = 256;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 0;
@ -1772,8 +1766,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
@ -1798,7 +1790,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
@ -1973,10 +1964,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
params->separate_with_bos = false;
} else if (arg == "--force-reshuffle") {
params->force_reshuffle = true;
} else if (arg == "--use-lbfgs") {
params->use_adam = false;
} else if (arg == "--use-adam") {
params->use_adam = true;
} else if (arg == "--no-flash") {
params->use_flash = false;
} else if (arg == "--use-flash") {
@ -2089,12 +2076,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "--lbfgs-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->lbfgs_n_iter = std::stoi(argv[i]);
} else if (arg == "--mem-model") {
if (++i >= argc) {
invalid_param = true;
@ -2200,7 +2181,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
double remaining_millis = 0.0;
if (data->millis_per_iter > 0.0) {
const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
const int n_iter = params->adam_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_millis = remaining_iter * data->millis_per_iter;
@ -2364,7 +2345,6 @@ int main(int argc, char ** argv) {
memset(opt, 0, sizeof(struct ggml_opt_context));
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = params.n_threads;
@ -2382,17 +2362,8 @@ int main(int argc, char ** argv) {
opt_params_adam.adam.gclip = params.adam_gclip;
opt_params_adam.adam.eps_f = params.adam_eps_f;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = params.n_threads;
opt_params_lbfgs.past = params.opt_past;
opt_params_lbfgs.delta = params.opt_delta;
opt_params_lbfgs.max_no_improvement = params.opt_max_no_improvement;
opt_params_lbfgs.n_gradient_accumulation = params.n_gradient_accumulation;
opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
opt->ctx = model.ctx;
opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
opt->params = opt_params_adam;
printf("%s: init model\n", __func__);
bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
@ -2401,7 +2372,7 @@ int main(int argc, char ** argv) {
}
set_param_model(&model);
opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
opt->params = opt_params_adam;
opt->iter = model.train_its;
printf("%s: opt iter %d\n", __func__, opt->iter);
@ -2563,7 +2534,7 @@ int main(int argc, char ** argv) {
size_t used_mem_after_opt = ggml_used_mem(ctx0);
int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
int n_iter = params.adam_n_iter;
model.train_its = opt->iter;
model.train_samples += n_batch * n_iter;
model.train_tokens += n_batch * n_tokens * n_iter;