remove lbfgs related train parameters

xaedes 2023-09-16 15:59:38 +02:00
parent ab56b63b27
commit 00b656f6db
2 changed files with 22 additions and 81 deletions

Changed file 1 of 2 (file path not preserved in the captured diff)

@@ -2081,7 +2081,6 @@ struct train_params {
     bool custom_n_rank_norm;
     bool custom_n_rank_output;
-    bool use_adam;
     bool use_flash;
     bool use_checkpointing;
@@ -2095,7 +2094,6 @@ struct train_params {
     bool force_reshuffle;
-    // only adam
     int warmup;
     int cos_decay_steps;
     float cos_decay_restart;
@@ -2106,7 +2104,6 @@ struct train_params {
     float opt_delta;
     int opt_max_no_improvement;
-    int lbfgs_n_iter;
     int adam_n_iter;
     float adam_alpha;
     float adam_min_alpha;
@@ -2179,7 +2176,6 @@ static struct train_params get_default_train_params() {
     params.custom_n_rank_norm = false;
     params.custom_n_rank_output = false;
-    params.use_adam = true;
     params.use_flash = true;
     params.use_checkpointing = true;
@@ -2196,14 +2192,12 @@ static struct train_params get_default_train_params() {
     params.opt_delta = 1e-5f;
     params.opt_max_no_improvement = 0;
-    // only adam
     params.warmup = 100;
     params.cos_decay_steps = 1000;
     params.cos_decay_restart = 1.1f;
     params.cos_decay_min = 0.1f;
     params.enable_restart = false;
-    params.lbfgs_n_iter = 256;
     params.adam_n_iter = 256;
     params.adam_alpha = 1e-3f;
     params.adam_min_alpha = 0;
@@ -2262,7 +2256,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
     fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
     fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
     fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
-    fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
     fprintf(stderr, " --no-flash Don't use flash attention \n");
     fprintf(stderr, " --use-flash Use flash attention (default)\n");
     fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
@@ -2285,7 +2278,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
     fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
-    fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
     fprintf(stderr, "\n");
 }
@@ -2524,10 +2516,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         params->separate_with_bos = false;
     } else if (arg == "--force-reshuffle") {
         params->force_reshuffle = true;
-    } else if (arg == "--use-lbfgs") {
-        params->use_adam = false;
-    } else if (arg == "--use-adam") {
-        params->use_adam = true;
     } else if (arg == "--no-flash") {
         params->use_flash = false;
     } else if (arg == "--use-flash") {
@@ -2636,12 +2624,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
             break;
         }
         params->adam_gclip = std::stof(argv[i]);
-    } else if (arg == "--lbfgs-iter") {
-        if (++i >= argc) {
-            invalid_param = true;
-            break;
-        }
-        params->lbfgs_n_iter = std::stoi(argv[i]);
     } else if (arg == "-h" || arg == "--help") {
         train_print_usage(argc, argv, &default_params);
         exit(0);
@@ -2728,7 +2710,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
     double remaining_millis = 0.0;
     if (data->millis_per_iter > 0.0) {
-        const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
+        const int n_iter = params->adam_n_iter;
         const int done_iter = opt->iter - data->first_iter;
         const int remaining_iter = n_iter - done_iter;
         remaining_millis = remaining_iter * data->millis_per_iter;
@@ -2943,7 +2925,6 @@ int main(int argc, char ** argv) {
     lora.hparams.n_rank_output = n_rank_output;
     // set opt params from command line
-    if (params.use_adam) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
@@ -2961,17 +2942,6 @@ int main(int argc, char ** argv) {
     opt->params.adam.beta2 = params.adam_beta2;
     opt->params.adam.gclip = params.adam_gclip;
     opt->params.adam.eps_f = params.adam_eps_f;
-    } else {
-        opt->params = ggml_opt_default_params(GGML_OPT_LBFGS);
-        opt->params.print_forward_graph = false;
-        opt->params.print_backward_graph = false;
-        opt->params.n_threads = params.n_threads;
-        opt->params.past = params.opt_past;
-        opt->params.delta = params.opt_delta;
-        opt->params.max_no_improvement = params.opt_max_no_improvement;
-        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
-        opt->params.lbfgs.n_iter = params.lbfgs_n_iter;
-    }
     ggml_allocr * alloc = NULL;
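
With the L-BFGS branch removed, this file always configures an Adam optimizer context. Below is a minimal, hypothetical sketch of that setup collected into one helper; init_opt_params_adam is not a function from the commit, the field assignments mirror the hunks above, and adam.n_iter / adam.alpha / adam.beta1 are assumed to be set from the matching train_params members (corresponding to --adam-iter, --adam-alpha, --adam-beta1).

    #include "ggml.h"

    // Hypothetical helper: Adam-only optimizer setup, mirroring the hunks above.
    // Assumes `train_params` is the struct defined in this file.
    static void init_opt_params_adam(struct ggml_opt_context * opt, const struct train_params & params) {
        opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
        opt->params.print_forward_graph     = false;
        opt->params.print_backward_graph    = false;
        opt->params.n_threads               = params.n_threads;
        opt->params.past                    = params.opt_past;
        opt->params.delta                   = params.opt_delta;
        opt->params.max_no_improvement      = params.opt_max_no_improvement;
        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
        opt->params.adam.n_iter             = params.adam_n_iter;  // assumed field, matches --adam-iter
        opt->params.adam.alpha              = params.adam_alpha;   // assumed field, matches --adam-alpha
        opt->params.adam.beta1              = params.adam_beta1;   // assumed field, matches --adam-beta1
        opt->params.adam.beta2              = params.adam_beta2;
        opt->params.adam.gclip              = params.adam_gclip;
        opt->params.adam.eps_f              = params.adam_eps_f;
    }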

Changed file 2 of 2 (file path not preserved in the captured diff)

@@ -1620,7 +1620,6 @@ struct train_params {
     int print_info_interval;
-    bool use_adam;
     bool use_flash;
     bool use_checkpointing;
     bool use_alloc;
@@ -1635,7 +1634,6 @@ struct train_params {
     bool force_reshuffle;
-    // only adam
     int warmup;
     int cos_decay_steps;
     float cos_decay_restart;
@@ -1646,7 +1644,6 @@ struct train_params {
     float opt_delta;
     int opt_max_no_improvement;
-    int lbfgs_n_iter;
     int adam_n_iter;
     float adam_alpha;
     float adam_min_alpha;
@@ -1693,7 +1690,6 @@ struct train_params get_default_train_params() {
     params.print_info_interval = 1;
-    params.use_adam = true;
     params.use_flash = true;
     params.use_checkpointing = true;
     params.use_alloc = true;
@@ -1711,14 +1707,12 @@ struct train_params get_default_train_params() {
     params.opt_delta = 1e-5f;
     params.opt_max_no_improvement = 0;
-    // only adam
     params.warmup = 100;
     params.cos_decay_steps = 1000;
     params.cos_decay_restart = 1.1f;
     params.cos_decay_min = 0.1f;
     params.enable_restart = false;
-    params.lbfgs_n_iter = 256;
     params.adam_n_iter = 256;
     params.adam_alpha = 1e-3f;
     params.adam_min_alpha = 0;
@@ -1772,8 +1766,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
     fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
     fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
     fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
-    fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
-    fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
     fprintf(stderr, " --no-flash Don't use flash attention \n");
     fprintf(stderr, " --use-flash Use flash attention (default)\n");
     fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
@@ -1798,7 +1790,6 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par
     fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
-    fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
     fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
@@ -1973,10 +1964,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         params->separate_with_bos = false;
     } else if (arg == "--force-reshuffle") {
         params->force_reshuffle = true;
-    } else if (arg == "--use-lbfgs") {
-        params->use_adam = false;
-    } else if (arg == "--use-adam") {
-        params->use_adam = true;
     } else if (arg == "--no-flash") {
         params->use_flash = false;
     } else if (arg == "--use-flash") {
@@ -2089,12 +2076,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
             break;
         }
         params->adam_gclip = std::stof(argv[i]);
-    } else if (arg == "--lbfgs-iter") {
-        if (++i >= argc) {
-            invalid_param = true;
-            break;
-        }
-        params->lbfgs_n_iter = std::stoi(argv[i]);
     } else if (arg == "--mem-model") {
         if (++i >= argc) {
             invalid_param = true;
@@ -2200,7 +2181,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
     double remaining_millis = 0.0;
     if (data->millis_per_iter > 0.0) {
-        const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
+        const int n_iter = params->adam_n_iter;
         const int done_iter = opt->iter - data->first_iter;
         const int remaining_iter = n_iter - done_iter;
         remaining_millis = remaining_iter * data->millis_per_iter;
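
Since use_adam is gone, the ETA estimate in opt_callback always uses adam_n_iter. Below is a self-contained sketch of that calculation; estimate_remaining_millis and the example numbers are illustrative, not code from the commit, while the variable names follow the hunk above.

    #include <cstdio>

    // Sketch: remaining-time estimate per optimizer run, as in the hunk above.
    // n_iter is now always the Adam iteration count.
    static double estimate_remaining_millis(int adam_n_iter, int opt_iter, int first_iter, double millis_per_iter) {
        if (millis_per_iter <= 0.0) {
            return 0.0;                               // no timing data yet
        }
        const int n_iter         = adam_n_iter;       // previously: use_adam ? adam_n_iter : lbfgs_n_iter
        const int done_iter      = opt_iter - first_iter;
        const int remaining_iter = n_iter - done_iter;
        return remaining_iter * millis_per_iter;
    }

    int main() {
        // Hypothetical numbers: 256 Adam iterations per batch, 100 already done, ~1200 ms each.
        printf("remaining: %.0f ms\n", estimate_remaining_millis(256, 100, 0, 1200.0));
        return 0;
    }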
@@ -2364,7 +2345,6 @@ int main(int argc, char ** argv) {
     memset(opt, 0, sizeof(struct ggml_opt_context));
     struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
-    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
     opt_params_adam.n_threads = params.n_threads;
@@ -2382,17 +2362,8 @@ int main(int argc, char ** argv) {
     opt_params_adam.adam.gclip = params.adam_gclip;
     opt_params_adam.adam.eps_f = params.adam_eps_f;
-    opt_params_lbfgs.print_forward_graph = false;
-    opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = params.n_threads;
-    opt_params_lbfgs.past = params.opt_past;
-    opt_params_lbfgs.delta = params.opt_delta;
-    opt_params_lbfgs.max_no_improvement = params.opt_max_no_improvement;
-    opt_params_lbfgs.n_gradient_accumulation = params.n_gradient_accumulation;
-    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
     opt->ctx = model.ctx;
-    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = opt_params_adam;
     printf("%s: init model\n", __func__);
     bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
@@ -2401,7 +2372,7 @@ int main(int argc, char ** argv) {
     }
     set_param_model(&model);
-    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = opt_params_adam;
     opt->iter = model.train_its;
     printf("%s: opt iter %d\n", __func__, opt->iter);
@@ -2563,7 +2534,7 @@ int main(int argc, char ** argv) {
     size_t used_mem_after_opt = ggml_used_mem(ctx0);
-    int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
+    int n_iter = params.adam_n_iter;
     model.train_its = opt->iter;
     model.train_samples += n_batch * n_iter;
     model.train_tokens += n_batch * n_tokens * n_iter;
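
The training statistics at the end of this hunk now always advance by adam_n_iter per optimizer run. A small worked example, using the default adam_n_iter = 256 from this diff together with hypothetical n_batch and n_tokens values:

    #include <cstdio>
    #include <cstdint>

    int main() {
        // adam_n_iter default is from this diff; n_batch and n_tokens are example values only.
        const int n_iter   = 256;   // params.adam_n_iter
        const int n_batch  = 8;     // example batch size
        const int n_tokens = 512;   // example context length

        const int64_t new_samples = (int64_t) n_batch * n_iter;            // model.train_samples += ...
        const int64_t new_tokens  = (int64_t) n_batch * n_tokens * n_iter; // model.train_tokens  += ...

        printf("samples per opt run: %lld\n", (long long) new_samples);    // 2048
        printf("tokens  per opt run: %lld\n", (long long) new_tokens);     // 1048576
        return 0;
    }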