diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp
index 418cc5fff..ecdb418bf 100644
--- a/examples/baby-llama/baby-llama-text.cpp
+++ b/examples/baby-llama/baby-llama-text.cpp
@@ -2150,66 +2150,341 @@ float cosine_decay_restart(int decay_steps, const float alpha, int step, float r
     return cosine_decay(decay_steps, alpha, step);
 }
 
-int main(int argc, char ** argv) {
-    const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
-    const char * default_train = "shakespeare.txt";
-    const char * default_chkpt_in = "checkpoint.bin";
-    const char * default_chkpt_out = "checkpoint.bin";
-    const char * default_model_out = "ggml-checkpoint-f32.bin";
-    const char * default_argv[6] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out, default_model_out};
+struct train_params {
+    const char * fn_vocab_model;
+    const char * fn_train_data;
+    const char * fn_checkpoint_in;
+    const char * fn_checkpoint_out;
+    const char * fn_model_out;
 
-    if (argc < 6) {
-        fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out model_out\n", argv[0]);
-        //return 1;
-    }
+    int seed;
+    int n_ctx;
+    int n_embd;
+    int n_mult;
+    int n_head;
+    int n_layer;
+    int n_rotmax;
 
-    int seed = 1;
-    int n_ctx = 256;
-    // int n_ctx = 64;
-    int n_embd = 256;
-    int n_mult = 256;
-    int n_head = 8;
-    int n_layer = 16;
-    int n_rotmax = 64;
-
-    int n_threads = 6;
-    int n_batch = 8;
-    int n_examples = 32;
+    int n_threads;
+    int n_batch;
+    int n_examples;
+    int n_predict;
 
-    int print_info_interval = 1;
-    int print_details_interval = 2;
+    int print_info_interval;
+    int print_details_interval;
 
-    bool samples_start_after_nl = false;
-    bool use_adam = true;
-    bool use_flash = false;
+    bool samples_start_after_nl;
+    bool use_adam;
+    bool use_flash;
 
     // only adam
-    int   warmup = 100;
-    int   cos_decay_steps = 1000;
-    float cos_decay_restart = 1.1f;
-    float cos_decay_alpha = 0.0f;
-
-    int   lbfgs_n_iter = 16;
-    int   adam_n_iter = 16;
-    float adam_alpha = 1e-3;
-    float adam_decay = 1e-3;
+    int   warmup;
+    int   cos_decay_steps;
+    float cos_decay_restart;
+    float cos_decay_alpha;
 
-    if (seed < 0) {
-        srand(time(NULL));
-    } else {
-        srand(seed);
+    int   lbfgs_n_iter;
+    int   adam_n_iter;
+    float adam_alpha;
+    float adam_decay;
+};
+
+struct train_params get_default_train_params() {
+    struct train_params params;
+    params.fn_vocab_model    = "ggml-vic7b-uncensored-q4_0.bin";
+    params.fn_train_data     = "shakespeare.txt";
+    params.fn_checkpoint_in  = "checkpoint.bin";
+    params.fn_checkpoint_out = "checkpoint.bin";
+    params.fn_model_out      = "ggml-checkpoint-f32.bin";
+
+    params.seed = -1;
+
+    params.n_ctx    = 128;
+    params.n_embd   = 256;
+    params.n_mult   = 256;
+    params.n_head   = 8;
+    params.n_layer  = 16;
+    params.n_rotmax = 64;
+
+    params.n_threads  = 6;
+    params.n_batch    = 8;
+    params.n_examples = 8;
+    params.n_predict  = 1024;
+
+    params.print_info_interval    = 1;
+    params.print_details_interval = 2;
+
+    params.samples_start_after_nl = false;
+    params.use_adam  = true;
+    params.use_flash = true;
+
+    // only adam
+    params.warmup            = 100;
+    params.cos_decay_steps   = 1000;
+    params.cos_decay_restart = 1.1f;
+    params.cos_decay_alpha   = 0.0f;
+
+    params.lbfgs_n_iter = 16;
+    params.adam_n_iter  = 16;
+    params.adam_alpha   = 1e-3;
+    params.adam_decay   = 1e-3;
+
+    return params;
+}
+
+void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+    fprintf(stderr, "  --vocab-model FNAME        model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
+    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
+    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+    fprintf(stderr, "  --model-out FNAME          path to save ggml model (default '%s')\n", params->fn_model_out);
+    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for < 0)\n");
+    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
+    fprintf(stderr, "  --embd N                   Embedding size used for new models (default %d)\n", params->n_embd);
+    fprintf(stderr, "  --mult N                   Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult);
+    fprintf(stderr, "  --head N                   Number of heads for new models (default %d)\n", params->n_head);
+    fprintf(stderr, "  --layer N                  Number of layers for new models (default %d)\n", params->n_layer);
+    fprintf(stderr, "  --rotmax N                 Maximal number of Rope dimensions for new models (default %d)\n", params->n_rotmax);
+    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  -n N, --examples N         Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, "  --predict N                Number of tokens to generate after training (default %d)\n", params->n_predict);
+    fprintf(stderr, "  --print-info-interval N    Print infos during training each N examples (default %d)\n", params->print_info_interval);
+    fprintf(stderr, "  --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval);
+    fprintf(stderr, "  --samples-after-nl         Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
+    fprintf(stderr, "  --use-lbfgs                Use LBFGS optimizer instead of default Adam\n");
+    fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-alpha N        Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
+    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater than zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "\n");
+}
+
+bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+    bool invalid_param = false;
+    std::string arg;
+    struct train_params default_params = get_default_train_params();
+    const std::string arg_prefix = "--";
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        if (arg == "--vocab-model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_vocab_model = argv[i];
+        } else if (arg == "--train-data") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_train_data = argv[i];
+        } else if (arg == "--checkpoint-in") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_in = argv[i];
+        } else if (arg == "--checkpoint-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_out = argv[i];
+        } else if (arg == "--model-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_out = argv[i];
+        } else if (arg == "-s" || arg == "--seed") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->seed = std::stoi(argv[i]);
+        } else if (arg == "-c" || arg == "--ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--embd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_embd = std::stoi(argv[i]);
+        } else if (arg == "--mult") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_mult = std::stoi(argv[i]);
+        } else if (arg == "--head") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_head = std::stoi(argv[i]);
+        } else if (arg == "--layer") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_layer = std::stoi(argv[i]);
+        } else if (arg == "--rotmax") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rotmax = std::stoi(argv[i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_threads = std::stoi(argv[i]);
+        } else if (arg == "-b" || arg == "--batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "-n" || arg == "--examples") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_examples = std::stoi(argv[i]);
+        } else if (arg == "--predict") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_predict = std::stoi(argv[i]);
+        } else if (arg == "--print-info-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_info_interval = std::stoi(argv[i]);
+        } else if (arg == "--print-details-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_details_interval = std::stoi(argv[i]);
+        } else if (arg == "--samples-after-nl") {
+            params->samples_start_after_nl = true;
+        } else if (arg == "--use-lbfgs") {
+            params->use_adam = false;
+        } else if (arg == "--use-adam") {
+            params->use_adam = true;
+        } else if (arg == "--no-flash") {
+            params->use_flash = false;
+        } else if (arg == "--use-flash") {
+            params->use_flash = true;
+        } else if (arg == "--warmup") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->warmup = std::stoi(argv[i]);
+        } else if (arg == "--cos-decay-steps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_steps = std::stoi(argv[i]);
+        } else if (arg == "--cos-decay-restart") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_restart = std::stof(argv[i]);
+        } else if (arg == "--cos-decay-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_alpha = std::stof(argv[i]);
+        } else if (arg == "--lbfgs-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lbfgs_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_alpha = std::stof(argv[i]);
+        } else if (arg == "--adam-decay") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "-h" || arg == "--help") {
+            train_print_usage(argc, argv, &default_params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            train_print_usage(argc, argv, &default_params);
+            exit(1);
+        }
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        train_print_usage(argc, argv, &default_params);
+        exit(1);
     }
 
-    const char * fn_model     = (argc >= 2) ? argv[1] : default_argv[1];
-    const char * fn_train     = (argc >= 3) ? argv[2] : default_argv[2];
-    const char * fn_chkpt_in  = (argc >= 4) ? argv[3] : default_argv[3];
-    const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
-    const char * fn_model_out = (argc >= 6) ? argv[5] : default_argv[5];
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    struct train_params params = get_default_train_params();
+
+    if (!train_params_parse(argc, argv, &params)) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        srand(time(NULL));
+    } else {
+        srand(params.seed);
+    }
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
 
-    struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
+    struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params);
 
     struct llama_vocab vocab;
     {
@@ -2232,19 +2507,19 @@ int main(int argc, char ** argv) {
     printf("%s: tokenize training data\n", __func__);
     std::vector<llama_token> train_tokens;
-    if (tokenize_file(lctx, fn_train, train_tokens) < 0) {
-        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, fn_train);
+    if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
+        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
     }
     printf("%s: number of training tokens: %d\n", __func__, train_tokens.size());
 
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx   = n_ctx;
-    model.hparams.n_embd  = n_embd;
-    model.hparams.n_mult  = n_mult;
-    model.hparams.n_head  = n_head;
-    model.hparams.n_layer = n_layer;
-    model.hparams.n_rot   = std::min((uint32_t)n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+    model.hparams.n_ctx   = params.n_ctx;
+    model.hparams.n_embd  = params.n_embd;
+    model.hparams.n_mult  = params.n_mult;
+    model.hparams.n_head  = params.n_head;
+    model.hparams.n_layer = params.n_layer;
+    model.hparams.n_rot   = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
 
     print_params(&model.hparams);
@@ -2282,6 +2557,7 @@ int main(int argc, char ** argv) {
 
     int n_tokens = model.hparams.n_ctx;
     int n_vocab  = model.hparams.n_vocab;
+    int n_batch  = params.n_batch;
 
     struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
     memset(opt, 0, sizeof(struct ggml_opt_context));
@@ -2290,32 +2566,32 @@
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = n_threads;
-    opt_params_adam.adam.n_iter = adam_n_iter;
-    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = adam_alpha;
-    opt_params_adam.adam.decay = adam_decay;
+    opt_params_adam.n_threads   = params.n_threads;
+    opt_params_adam.adam.n_iter = params.adam_n_iter;
+    opt_params_adam.adam.sched  = 1.0f;
+    opt_params_adam.adam.alpha  = params.adam_alpha;
+    opt_params_adam.adam.decay  = params.adam_decay;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = n_threads;
-    opt_params_lbfgs.lbfgs.n_iter = lbfgs_n_iter;
+    opt_params_lbfgs.n_threads    = params.n_threads;
+    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
 
     opt->ctx = model.ctx;
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint(&model, opt, fn_chkpt_in, true);
+    bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true);
     set_param_model(&model);
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     opt->iter = model.train_its;
     printf("%s: opt iter %d\n", __func__, opt->iter);
 
     bool from_scratch = !existed;
     if (from_scratch) {
-        randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
    }
 
     init_kv_cache(&kv_self, &model, 1);
@@ -2328,11 +2604,11 @@
     size_t    compute_size = 1024ll*1024ll*1024ll*32ll;
     uint8_t * compute_addr = new uint8_t[compute_size];
-
+    GGML_ASSERT(train_tokens.size() > n_tokens);
     std::vector<int> train_samples;
     train_samples.push_back(0);
     for (int i=1; i<train_tokens.size()-n_tokens; ++i) {
-        if (!samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
+        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
             train_samples.push_back(i);
         }
     }
 
-    for (int ex=0; ex<n_examples; ++ex) {
+    for (int ex=0; ex<params.n_examples; ++ex) {
         if (ex*n_batch >= train_samples.size()) {
             shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
             for (int i=0; i<train_samples.size(); ++i) {
                 GGML_ASSERT(train_samples[i]+n_tokens-1 < train_tokens.size());
             }
         }
 
-        opt->params.adam.sched = (opt->iter < warmup)
-            ? (float) opt->iter / (float) warmup
-            : cosine_decay_restart(cos_decay_steps, cos_decay_alpha, opt->iter - warmup, cos_decay_restart);
+        opt->params.adam.sched = (opt->iter < params.warmup)
+            ? (float) opt->iter / (float) params.warmup
+            : cosine_decay_restart(
+                params.cos_decay_steps,
+                params.cos_decay_alpha,
+                opt->iter - params.warmup,
+                params.cos_decay_restart);
+
+        printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
         // ggml_opt(ctx0, opt->params, e);
@@ -2406,8 +2687,7 @@ int main(int argc, char ** argv) {
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
-
-        if (ex % print_info_interval == 0) {
+        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
             printf("error_before_opt: %.6f\n", error_before_opt);
             printf("error_after_opt: %.6f\n", error_after_opt);
@@ -2415,7 +2695,7 @@ int main(int argc, char ** argv) {
             printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
 
-        if (ex % print_details_interval == 0) {
+        if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) {
             // set_logits_masked(logits, token_notavail, -1e9);
             for (int i=0; i<n_batch; ++i) {
 
-    if (n_examples > 0) {
-        save_checkpoint(&model, opt, fn_chkpt_out);
+    if (params.n_examples > 0) {
+        save_checkpoint(&model, opt, params.fn_checkpoint_out);
     }
 
-    if (strlen(fn_model_out) > 0) {
-        save_as_llama_model(&vocab, &model, fn_model_out);
+    if (strlen(params.fn_model_out) > 0) {
+        save_as_llama_model(&vocab, &model, params.fn_model_out);
     }
 
     {
-        int n_gen = 1024;
+        int n_gen = params.n_predict;
         int sample_ctx = n_tokens - n_tokens/8;
 
         sampler.params.temp = 0.2;
@@ -2477,15 +2757,15 @@
         printf("---\n");
         for (int i=0; i
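
Usage note (not part of the patch): the diff replaces the old positional arguments with a defaults-struct-plus-parser pattern — get_default_train_params() supplies every value, train_params_parse() overrides individual fields from argv, and main() only consumes the resulting struct. Below is a minimal sketch of how the three pieces compose, assuming only the train_params struct and the two functions added above; the helper name resolve_train_params and the flag values shown in the comment are illustrative, not part of the patch.

    // Hypothetical driver sketch; compiles only together with the definitions of
    // train_params, get_default_train_params() and train_params_parse() from this patch.
    #include <cstdio>
    #include <cstdlib>
    #include <ctime>

    static struct train_params resolve_train_params(int argc, char ** argv) {
        // 1) start from the compiled-in defaults
        struct train_params params = get_default_train_params();

        // 2) let command-line flags override them, e.g.:
        //      ./baby-llama-text --train-data shakespeare.txt --ctx 64 --batch 16 --use-lbfgs
        if (!train_params_parse(argc, argv, &params)) {
            exit(1); // the parser already printed usage on error
        }

        // 3) seed the RNG the same way the new main() does:
        //    a negative seed selects a time-based random seed
        if (params.seed < 0) {
            srand(time(NULL));
        } else {
            srand(params.seed);
        }
        return params;
    }

One payoff of keeping the defaults in a single function is that the help text cannot drift: train_params_parse() fetches a fresh default_params and hands it to train_print_usage(), so the value printed after each option is always the real default rather than a hard-coded copy.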