diff --git a/common/train.cpp b/common/train.cpp index 81039e5eb..d22d4b036 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -1006,3 +1006,326 @@ std::string get_train_filename(const char * filename, const char * pattern_it, c std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); return replace_str(filename, pattern_it, sit.c_str()); } + +struct train_params_common get_default_train_params_common() { + struct train_params_common params; + params.fn_train_data = "shakespeare.txt"; + params.fn_checkpoint_in = "checkpoint.gguf"; + params.fn_checkpoint_out = "checkpoint-ITERATION.gguf"; + params.pattern_fn_it = "ITERATION"; + params.fn_latest = "LATEST"; + + params.print_usage = false; + + params.save_every = 10; + + params.seed = -1; + + params.n_ctx = 128; + params.n_threads = 6; + params.n_batch = 8; + params.n_gradient_accumulation = 1; + + params.custom_n_ctx = false; + + params.use_flash = true; + params.use_checkpointing = true; + + params.sample_start = ""; + params.include_sample_start = false; + params.escape = false; + params.overlapping_samples = false; + params.fill_with_next_samples = false; + params.separate_with_eos = false; + params.separate_with_bos = true; + params.force_reshuffle = false; + + params.opt_past = 0; + params.opt_delta = 1e-5f; + params.opt_max_no_improvement = 0; + + params.warmup = 100; + params.cos_decay_steps = 1000; + params.cos_decay_restart = 1.1f; + params.cos_decay_min = 0.1f; + params.enable_restart = false; + + params.adam_n_iter = 256; + params.adam_alpha = 1e-3f; + params.adam_min_alpha = 0; + params.adam_decay = 1e-1f; + params.adam_decay_min_ndim = 2; + params.adam_beta1 = 0.9f; + params.adam_beta2 = 0.999f; + params.adam_gclip = 1.0f; + params.adam_eps_f = 0.0f; + return params; +} + +void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) { + // fprintf(stderr, "usage: %s [options]\n", argv[0]); + // fprintf(stderr, "\n"); + // fprintf(stderr, "options:\n"); + // fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); + fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); + fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); + fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); + fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); + fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); + fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); + fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); + fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); + fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); + fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. 
If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
+    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
+    fprintf(stderr, "  --escape                   process sample start escape sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+    fprintf(stderr, "  --overlapping-samples      Samples may overlap, will include sample-start of second and following samples. When off, samples will end at the beginning of the next sample. (default off)\n");
+    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
+    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention\n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
+    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
+    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
+    fprintf(stderr, "  --enable-restart           Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --disable-restart          Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater than zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with fewer dimensions. (default %d)\n", params->adam_decay_min_ndim);
+    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, "\n");
+}
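+
+// Expected usage (illustrative sketch, not prescriptive): a program embeds a
+// train_params_common inside its own params struct and tries
+// consume_common_train_arg() first in its argument loop, falling back to its
+// program-specific flags, e.g.:
+//
+//     if (consume_common_train_arg(argc, argv, &i, &params.common, &invalid_param)) {
+//         if (invalid_param)             { /* report the bad argument and exit */ }
+//         if (params.common.print_usage) { /* print usage and exit(0) */ }
+//     } else if (/* program-specific arguments */) {
+//         // ...
+//     }
+//     // after the loop:
+//     finish_processing_train_args(&params.common);
+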
+bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param) {
+    int& i = *idx;
+    std::string arg = argv[i];
+    const std::string arg_prefix = "--";
+    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+        std::replace(arg.begin(), arg.end(), '_', '-');
+    }
+    if (arg == "--train-data") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_train_data = argv[i];
+    } else if (arg == "--checkpoint-in") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_in = argv[i];
+    } else if (arg == "--checkpoint-out") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_out = argv[i];
+    } else if (arg == "--pattern-fn-it") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->pattern_fn_it = argv[i];
+    } else if (arg == "--fn-latest") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_latest = argv[i];
+    } else if (arg == "--save-every") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->save_every = std::stoi(argv[i]);
+    } else if (arg == "-s" || arg == "--seed") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->seed = std::stoi(argv[i]);
+    } else if (arg == "-c" || arg == "--ctx") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_ctx = std::stoi(argv[i]);
+        params->custom_n_ctx = true;
+    } else if (arg == "-t" || arg == "--threads") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_threads = std::stoi(argv[i]);
+    } else if (arg == "-b" || arg == "--batch") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_batch = std::stoi(argv[i]);
+    } else if (arg == "--grad-acc") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
+    } else if (arg == "--sample-start") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->sample_start = std::string(argv[i]);
+    } else if (arg == "--escape") {
+        params->escape = true;
+    } else if (arg == "--include-sample-start") {
+        params->include_sample_start = true;
+    } else if (arg == "--overlapping-samples") {
+        params->overlapping_samples = true;
+    } else if (arg == "--fill-with-next-samples") {
+        params->fill_with_next_samples = true;
+    } else if (arg == 
"--separate-with-eos") { + params->separate_with_eos = true; + } else if (arg == "--separate-with-bos") { + params->separate_with_bos = true; + } else if (arg == "--no-separate-with-eos") { + params->separate_with_eos = false; + } else if (arg == "--no-separate-with-bos") { + params->separate_with_bos = false; + } else if (arg == "--force-reshuffle") { + params->force_reshuffle = true; + } else if (arg == "--no-flash") { + params->use_flash = false; + } else if (arg == "--use-flash") { + params->use_flash = true; + } else if (arg == "--no-checkpointing") { + params->use_checkpointing = false; + } else if (arg == "--use-checkpointing") { + params->use_checkpointing = true; + } else if (arg == "--warmup") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->warmup = std::stoi(argv[i]); + } else if (arg == "--cos-decay-steps") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_steps = std::stoi(argv[i]); + } else if (arg == "--cos-decay-restart") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_restart = std::stof(argv[i]); + } else if (arg == "--cos-decay-min") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_min = std::stof(argv[i]); + } else if (arg == "--enable-restart") { + params->enable_restart = true; + } else if (arg == "--disable-restart") { + params->enable_restart = false; + } else if (arg == "--opt-past") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_past = std::stoi(argv[i]); + } else if (arg == "--opt-delta") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_delta = std::stof(argv[i]); + } else if (arg == "--opt-max-no-improvement") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_max_no_improvement = std::stoi(argv[i]); + } else if (arg == "--adam-epsf") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_eps_f = std::stof(argv[i]); + } else if (arg == "--adam-iter") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_n_iter = std::stoi(argv[i]); + } else if (arg == "--adam-alpha") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_alpha = std::stof(argv[i]); + } else if (arg == "--adam-min-alpha") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_min_alpha = std::stof(argv[i]); + } else if (arg == "--adam-decay") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_decay = std::stof(argv[i]); + } else if (arg == "--adam-decay-min-ndim") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_decay_min_ndim = std::stoi(argv[i]); + } else if (arg == "--adam-beta1") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_beta1 = std::stof(argv[i]); + } else if (arg == "--adam-beta2") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_beta2 = std::stof(argv[i]); + } else if (arg == "--adam-gclip") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_gclip = std::stof(argv[i]); + } else if (arg == "-h" || arg == "--help") { + params->print_usage = true; + return true; + } else { + return false; + } + return true; +} + +void finish_processing_train_args(struct train_params_common * params) { + if (params->escape) { + process_escapes(params->sample_start); + } +} diff --git a/common/train.h 
b/common/train.h index 59004a87c..cc3673c36 100644 --- a/common/train.h +++ b/common/train.h @@ -26,9 +26,69 @@ struct train_state { size_t shuffle_next_sample; }; +struct train_params_common { + const char * fn_train_data; + const char * fn_checkpoint_in; + const char * fn_checkpoint_out; + const char * pattern_fn_it; + const char * fn_latest; + + bool print_usage; + + int save_every; + + uint32_t seed; + + int n_ctx; + int n_threads; + int n_batch; + int n_gradient_accumulation; + + bool custom_n_ctx; + + bool use_flash; + bool use_checkpointing; + + std::string sample_start; + bool include_sample_start; + bool escape; + bool overlapping_samples; + bool fill_with_next_samples; + bool separate_with_eos; + bool separate_with_bos; + + bool force_reshuffle; + + int warmup; + int cos_decay_steps; + float cos_decay_restart; + float cos_decay_min; + bool enable_restart; + + int opt_past; + float opt_delta; + int opt_max_no_improvement; + + int adam_n_iter; + float adam_alpha; + float adam_min_alpha; + float adam_decay; + int adam_decay_min_ndim; + float adam_beta1; + float adam_beta2; + float adam_gclip; + float adam_eps_f; +}; + struct train_state * init_train_state(int seed); void free_train_state(struct train_state * state); +struct train_params_common get_default_train_params_common(); +void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params); + +bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param); +void finish_processing_train_args(struct train_params_common * params); + struct random_normal_distribution; struct random_uniform_distribution; diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 5c787e94e..09a29340a 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1197,24 +1197,10 @@ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lor } struct train_params { + struct train_params_common common; + const char * fn_model_base; - const char * fn_train_data; - const char * fn_checkpoint_in; - const char * fn_checkpoint_out; const char * fn_lora_out; - const char * pattern_fn_it; - const char * fn_latest; - - int save_every; - - uint32_t seed; - - int n_ctx; - int n_threads; - int n_batch; - int n_gradient_accumulation; - - bool custom_n_ctx; bool only_write_lora; @@ -1255,61 +1241,13 @@ struct train_params { bool custom_n_rank_tok_embeddings; bool custom_n_rank_norm; bool custom_n_rank_output; - - bool use_flash; - bool use_checkpointing; - - std::string sample_start; - bool include_sample_start; - bool escape; - bool overlapping_samples; - bool fill_with_next_samples; - bool separate_with_eos; - bool separate_with_bos; - - bool force_reshuffle; - - int warmup; - int cos_decay_steps; - float cos_decay_restart; - float cos_decay_min; - bool enable_restart; - - int opt_past; - float opt_delta; - int opt_max_no_improvement; - - int adam_n_iter; - float adam_alpha; - float adam_min_alpha; - float adam_decay; - int adam_decay_min_ndim; - float adam_beta1; - float adam_beta2; - float adam_gclip; - float adam_eps_f; }; static struct train_params get_default_train_params() { struct train_params params; + params.common = get_default_train_params_common(); params.fn_model_base = ""; - params.fn_train_data = "shakespeare.txt"; - params.fn_checkpoint_in = "checkpoint.gguf"; - params.fn_checkpoint_out = "checkpoint-ITERATION.gguf"; params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf"; - 
params.pattern_fn_it = "ITERATION"; - params.fn_latest = "LATEST"; - - params.save_every = 10; - - params.seed = -1; - - params.n_ctx = 128; - params.n_threads = 6; - params.n_batch = 8; - params.n_gradient_accumulation = 1; - - params.custom_n_ctx = false; params.only_write_lora = false; @@ -1351,59 +1289,18 @@ static struct train_params get_default_train_params() { params.custom_n_rank_norm = false; params.custom_n_rank_output = false; - params.use_flash = true; - params.use_checkpointing = true; - - params.sample_start = ""; - params.include_sample_start = false; - params.escape = false; - params.overlapping_samples = false; - params.fill_with_next_samples = false; - params.separate_with_eos = false; - params.separate_with_bos = true; - params.force_reshuffle = false; - - params.opt_past = 0; - params.opt_delta = 1e-5f; - params.opt_max_no_improvement = 0; - - params.warmup = 100; - params.cos_decay_steps = 1000; - params.cos_decay_restart = 1.1f; - params.cos_decay_min = 0.1f; - params.enable_restart = false; - - params.adam_n_iter = 256; - params.adam_alpha = 1e-3f; - params.adam_min_alpha = 0; - params.adam_decay = 1e-1f; - params.adam_decay_min_ndim = 2; - params.adam_beta1 = 0.9f; - params.adam_beta2 = 0.999f; - params.adam_gclip = 1.0f; - params.adam_eps_f = 0.0f; return params; } -static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) { +static void train_print_usage(int argc, char ** argv, const struct train_params * params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base); - fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); - fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); - fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out); - fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); - fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); - fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. 
(default '%d')\n", params->save_every); fprintf(stderr, " --only-write-lora only save llama lora, don't do any training\n"); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); - fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); - fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); - fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); - fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); @@ -1421,39 +1318,8 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n"); fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n"); fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n"); - fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str()); - fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); - fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); - fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); - fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : ""); - fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); - fprintf(stderr, " --no-flash Don't use flash attention \n"); - fprintf(stderr, " --use-flash Use flash attention (default)\n"); - fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); - fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n"); - fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); - fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); - fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. 
Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); - fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min); - fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past); - fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta); - fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement); - fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); - fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); - fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); - fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha); - fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); - fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim); - fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1); - fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); - fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. 
(default %f)\n", params->adam_gclip);
-    fprintf(stderr, "\n");
+
+    print_common_train_usage(argc, argv, &params->common);
 }
 
 static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@@ -1468,87 +1334,27 @@
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
 
-        if (arg == "--model-base") {
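+        // shared training flags are handled first; consume_common_train_arg()
+        // advances i past any value it reads and returns true once it has
+        // recognized the argument, so only finetune-specific flags remain below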
+        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
+            if (invalid_param) {
+                break;
+            } else if (params->common.print_usage) {
+                train_print_usage(argc, argv, &default_params);
+                exit(0);
+            }
+        } else if (arg == "--model-base") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_model_base = argv[i];
-        } else if (arg == "--train-data") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_train_data = argv[i];
-        } else if (arg == "--checkpoint-in") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_in = argv[i];
-        } else if (arg == "--checkpoint-out") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_out = argv[i];
         } else if (arg == "--lora-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_lora_out = argv[i];
-        } else if (arg == "--pattern-fn-it") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->pattern_fn_it = argv[i];
-        } else if (arg == "--fn-latest") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_latest = argv[i];
-        } else if (arg == "--save-every") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->save_every = std::stoi(argv[i]);
        } else if (arg == "--only-write-lora") {
             params->only_write_lora = true;
-        } else if (arg == "-s" || arg == "--seed") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->seed = std::stoi(argv[i]);
-        } else if (arg == "-c" || arg == "--ctx") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_ctx = std::stoi(argv[i]);
-            params->custom_n_ctx = true;
-        } else if (arg == "-t" || arg == "--threads") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_threads = std::stoi(argv[i]);
-        } else if (arg == "-b" || arg == "--batch") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_batch = std::stoi(argv[i]);
-        } else if (arg == "--grad-acc") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
         } else if (arg == "--norm-rms-eps") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -1667,141 +1473,6 @@
             params->n_rank_w3 = std::stoi(argv[i]);
             params->custom_n_rank_w3 = true;
-        } else if (arg == "--sample-start") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->sample_start = std::string(argv[i]);
-        } else if (arg == "--escape") {
-            params->escape = true;
-        } else if (arg == "--include-sample-start") {
-            params->include_sample_start = true;
-        } else if (arg == "--overlapping-samples") {
-            params->overlapping_samples = true;
-        } else if (arg == "--fill-with-next-samples") {
-            params->fill_with_next_samples = true;
-        } else if (arg == "--separate-with-eos") {
-            params->separate_with_eos = true;
-        } else if (arg == "--separate-with-bos") {
-            params->separate_with_bos = true;
-        } else if (arg == "--no-separate-with-eos") {
-            params->separate_with_eos = false;
-        } else if (arg == "--no-separate-with-bos") {
-            params->separate_with_bos = false;
-        } else if (arg == "--force-reshuffle") {
-            params->force_reshuffle = true;
-        } else if (arg == "--no-flash") {
-            params->use_flash = false;
-        } else if (arg == "--use-flash") {
-            params->use_flash = true;
-        } else if (arg == "--no-checkpointing") {
-            params->use_checkpointing = false;
-        } else if (arg == "--use-checkpointing") {
-            params->use_checkpointing = true;
-        } else if (arg == "--warmup") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->warmup = std::stoi(argv[i]);
-        } else if (arg == "--cos-decay-steps") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_steps = std::stof(argv[i]);
-        } else if (arg == "--cos-decay-restart") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_restart = std::stof(argv[i]);
-        } else if (arg == "--cos-decay-min") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->cos_decay_min = std::stof(argv[i]);
-        } else if (arg == "--enable-restart") {
-            params->enable_restart = true;
-        } else if (arg == "--disable-restart") {
-            params->enable_restart = false;
-        } else if (arg == "--opt-past") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_past = std::stoi(argv[i]);
-        } else if (arg == "--opt-delta") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_delta = std::stof(argv[i]);
-        } else if (arg == "--opt-max-no-improvement") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->opt_max_no_improvement = std::stoi(argv[i]);
-        } else if (arg == "--adam-epsf") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_eps_f = std::stof(argv[i]);
-        } else if (arg == "--adam-iter") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_n_iter = std::stoi(argv[i]);
-        } else if (arg == "--adam-alpha") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_alpha = std::stof(argv[i]);
-        } else if (arg == "--adam-min-alpha") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_min_alpha = std::stof(argv[i]);
-        } else if (arg == "--adam-decay") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_decay = std::stof(argv[i]);
-        } else if (arg == "--adam-decay-min-ndim") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_decay_min_ndim = std::stoi(argv[i]);
-        } else if (arg == "--adam-beta1") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_beta1 = std::stof(argv[i]);
-        } else if (arg == "--adam-beta2") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_beta2 = std::stof(argv[i]);
-        } else if (arg == "--adam-gclip") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->adam_gclip = std::stof(argv[i]);
-        } else if (arg == "-h" || arg == "--help") {
-            train_print_usage(argc, argv, &default_params);
-            exit(0);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             train_print_usage(argc, argv, &default_params);
@@ -1813,9 +1484,7 @@
         train_print_usage(argc, argv, &default_params);
         exit(1);
     }
-    if (params->escape) {
-        process_escapes(params->sample_start);
-    }
+    finish_processing_train_args(&params->common);
     return true;
 }
@@ -1844,31 +1513,31 @@
 }
 
 struct opt_callback_data {
-    struct train_params * params;
-    struct train_state * train;
-    save_train_files_callback 
save_cb; - void * save_data; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_probs; - int first_iter; - int64_t last_time; - double millis_per_iter; + struct train_params_common * params; + struct train_state * train; + save_train_files_callback save_cb; + void * save_data; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; static void opt_callback(void * vdata, int accum_step, float * sched) { - struct opt_callback_data * data = (struct opt_callback_data *) vdata; - struct train_params * params = data->params; - struct train_state * train = data->train; - struct ggml_opt_context * opt = train->opt; + struct opt_callback_data * data = (struct opt_callback_data *) vdata; + struct train_params_common * params = data->params; + struct train_state * train = data->train; + struct ggml_opt_context * opt = train->opt; int n_batch = params->n_batch; int n_ctx = params->n_ctx; @@ -2019,11 +1688,11 @@ int main(int argc, char ** argv) { return 1; } - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); + if (params.common.seed == LLAMA_DEFAULT_SEED) { + params.common.seed = time(NULL); } - printf("%s: seed: %u\n", __func__, params.seed); - srand(params.seed); + printf("%s: seed: %u\n", __func__, params.common.seed); + srand(params.common.seed); struct llama_context_params llama_params = llama_context_default_params(); llama_params.vocab_only = false; @@ -2033,11 +1702,11 @@ int main(int argc, char ** argv) { struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); struct my_llama_model model; - init_model(lmodel, &model, params.n_ctx); + init_model(lmodel, &model, params.common.n_ctx); struct my_llama_lora lora; - struct train_state * train = init_train_state(params.seed); + struct train_state * train = init_train_state(params.common.seed); struct ggml_opt_context * opt = train->opt; load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams); @@ -2083,30 +1752,30 @@ int main(int argc, char ** argv) { opt->params = ggml_opt_default_params(GGML_OPT_ADAM); opt->params.print_forward_graph = false; opt->params.print_backward_graph = false; - opt->params.n_threads = params.n_threads; - opt->params.past = params.opt_past; - opt->params.delta = params.opt_delta; - opt->params.max_no_improvement = params.opt_max_no_improvement; - opt->params.n_gradient_accumulation = params.n_gradient_accumulation; - opt->params.adam.n_iter = params.adam_n_iter; + opt->params.n_threads = params.common.n_threads; + opt->params.past = params.common.opt_past; + opt->params.delta = params.common.opt_delta; + opt->params.max_no_improvement = params.common.opt_max_no_improvement; + opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation; + opt->params.adam.n_iter = params.common.adam_n_iter; opt->params.adam.sched = 1.0f; - opt->params.adam.alpha = params.adam_alpha; - opt->params.adam.decay = params.adam_decay; - opt->params.adam.decay_min_ndim = 
params.adam_decay_min_ndim; - opt->params.adam.beta1 = params.adam_beta1; - opt->params.adam.beta2 = params.adam_beta2; - opt->params.adam.gclip = params.adam_gclip; - opt->params.adam.eps_f = params.adam_eps_f; + opt->params.adam.alpha = params.common.adam_alpha; + opt->params.adam.decay = params.common.adam_decay; + opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim; + opt->params.adam.beta1 = params.common.adam_beta1; + opt->params.adam.beta2 = params.common.adam_beta2; + opt->params.adam.gclip = params.common.adam_gclip; + opt->params.adam.eps_f = params.common.adam_eps_f; ggml_allocr * alloc = NULL; printf("%s: init model\n", __func__); - bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train); + bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train); if (existed) { // overwrite last n_ctx with user provided n_ctx - if (params.custom_n_ctx) { - model.hparams.n_ctx = params.n_ctx; + if (params.common.custom_n_ctx) { + model.hparams.n_ctx = params.common.n_ctx; } const bool opt_param_count_changed = ( @@ -2124,7 +1793,7 @@ int main(int argc, char ** argv) { || (lora.hparams.n_rank_output != n_rank_output) ); - const bool opt_past_changed = opt->params.past != params.opt_past; + const bool opt_past_changed = opt->params.past != params.common.opt_past; if (opt_param_count_changed) { print_lora_params(&lora.hparams); @@ -2139,7 +1808,7 @@ int main(int argc, char ** argv) { } } else { // existed == false init_lora(&model, &lora); - randomize_lora(&lora, params.seed, 0.0f, 1.0f, -1.0f, +1.0f); + randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f); if (!params.only_write_lora) { ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora)); } @@ -2159,8 +1828,8 @@ int main(int argc, char ** argv) { save_train_files_data save_data; save_data.fn_checkpoint_out = ""; save_data.fn_lora_out = params.fn_lora_out; - save_data.pattern_fn_it = params.pattern_fn_it; - save_data.fn_latest = params.fn_latest; + save_data.pattern_fn_it = params.common.pattern_fn_it; + save_data.fn_latest = params.common.fn_latest; save_data.model = &model; save_data.lora = &lora; @@ -2175,7 +1844,7 @@ int main(int argc, char ** argv) { int n_tokens = model.hparams.n_ctx; int n_vocab = model.hparams.n_vocab; - int n_batch = params.n_batch; + int n_batch = params.common.n_batch; printf("%s: opt iter %d\n", __func__, opt->iter); @@ -2215,7 +1884,7 @@ int main(int argc, char ** argv) { size_t estimated_compute_size_wo_data = ( ggml_tensor_overhead()*GGML_MAX_NODES*2 + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*( - params.use_checkpointing ? 3 : 2 + params.common.use_checkpointing ? 3 : 2 ) ); struct ggml_init_params ctx_compute_params = { @@ -2242,7 +1911,7 @@ int main(int argc, char ** argv) { gf = ggml_new_graph(ctx_compute); gf->order = (enum ggml_cgraph_eval_order) order; gb = ggml_new_graph(ctx_compute); - gb_tmp = params.use_checkpointing + gb_tmp = params.common.use_checkpointing ? 
ggml_new_graph(ctx_compute) : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
             gf, gb, gb_tmp,
             &logits, tokens_input, target_probs,
             n_tokens, n_batch,
-            params.use_flash,
-            params.use_checkpointing
+            params.common.use_flash,
+            params.common.use_checkpointing
         );
         size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
         if (max_compute_size < best_compute_size) {
@@ -2275,7 +1944,7 @@
     gf = ggml_new_graph(ctx_compute);
     gf->order = best_order;
     gb = ggml_new_graph(ctx_compute);
-    gb_tmp = params.use_checkpointing
+    gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph(ctx_compute)
             : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
         gf, gb, gb_tmp,
         &logits, tokens_input, target_probs,
         n_tokens, n_batch,
-        params.use_flash,
-        params.use_checkpointing
+        params.common.use_flash,
+        params.common.use_checkpointing
     );
     ggml_allocr_free(alloc);
@@ -2294,10 +1963,10 @@
     std::vector<llama_token> train_tokens;
     std::vector<size_t> train_samples_begin;
     std::vector<size_t> train_samples_size;
     printf("%s: tokenize training data\n", __func__);
     tokenize_file(lctx,
-            params.fn_train_data,
-            params.sample_start,
-            params.include_sample_start,
-            params.overlapping_samples,
+            params.common.fn_train_data,
+            params.common.sample_start,
+            params.common.include_sample_start,
+            params.common.overlapping_samples,
             n_tokens,
             train_tokens,
             train_samples_begin,
@@ -2318,16 +1987,16 @@
     }
     printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
-    size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+    size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
     const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
     if (changed_train_data) {
         printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
     }
-    if (params.force_reshuffle) {
+    if (params.common.force_reshuffle) {
         printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
     }
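+    // the shuffle state stored in the checkpoint is resumed by default; the RNG
+    // is reseeded only on a fresh run, when the training data has changed, or
+    // when --force-reshuffle is given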
-    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
-        train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
+    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+        train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
         train->shuffle_sample_count = train_samples_size.size();
         train->shuffle_next_sample = 0;
         train->shuffle_samples_hash = shuffle_samples_hash;
@@ -2347,15 +2016,15 @@
     printf("%s: begin training\n", __func__);
 
     save_train_files_data save_data;
-    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
     save_data.fn_lora_out = params.fn_lora_out;
-    save_data.pattern_fn_it = params.pattern_fn_it;
-    save_data.fn_latest = params.fn_latest;
+    save_data.pattern_fn_it = params.common.pattern_fn_it;
+    save_data.fn_latest = params.common.fn_latest;
     save_data.model = &model;
     save_data.lora = &lora;
 
     struct opt_callback_data opt_cb_data;
-    opt_cb_data.params = &params;
+    opt_cb_data.params = &params.common;
     opt_cb_data.train = train;
     opt_cb_data.save_cb = &save_train_files;
     opt_cb_data.save_data = &save_data;
@@ -2375,7 +2044,7 @@
     opt_cb_data.millis_per_iter = 0.0;
 
     // measure required memory for work buffer
-    size_t max_work_size = ggml_graph_plan(gb, params.n_threads).work_size + GGML_OBJECT_SIZE;
+    size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
     printf("%s: max_work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
 
     // context for work buffer
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 7984dd724..5b993b47b 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -692,17 +692,10 @@
 }
 
 struct train_params {
+    struct train_params_common common;
+
     const char * fn_vocab_model;
-    const char * fn_train_data;
-    const char * fn_checkpoint_in;
-    const char * fn_checkpoint_out;
     const char * fn_model_out;
-    const char * pattern_fn_it;
-    const char * fn_latest;
-
-    int save_every;
-
-    uint32_t seed;
 
     int n_ctx;
     int n_embd;
@@ -710,10 +703,7 @@
     int n_layer;
     int n_ff;
 
-    int n_threads;
     int n_examples;
-    int n_batch;
-    int n_gradient_accumulation;
 
     float f_norm_rms_eps;
     float rope_freq_base;
@@ -721,40 +711,8 @@
 
     int print_info_interval;
 
-    bool use_flash;
-    bool use_checkpointing;
     bool use_alloc;
 
-    std::string sample_start;
-    bool include_sample_start;
-    bool escape;
-    bool overlapping_samples;
-    bool fill_with_next_samples;
-    bool separate_with_eos;
-    bool separate_with_bos;
-
-    bool force_reshuffle;
-
-    int warmup;
-    int cos_decay_steps;
-    float cos_decay_restart;
-    float cos_decay_min;
-    bool enable_restart;
-
-    int opt_past;
-    float opt_delta;
-    int opt_max_no_improvement;
-
-    int adam_n_iter;
-    float adam_alpha;
-    float adam_min_alpha;
-    float adam_decay;
-    int adam_decay_min_ndim;
-    float adam_beta1;
-    float adam_beta2;
-    float adam_gclip;
-    float adam_eps_f;
-
     int mem_model_gb;
     int mem_compute_gb;
     int mem_compute0_gb;
@@ -762,17 +720,9 @@
 
struct train_params 
get_default_train_params() { struct train_params params; + params.common = get_default_train_params_common(); params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin"; - params.fn_train_data = "shakespeare.txt"; - params.fn_checkpoint_in = "checkpoint.bin"; - params.fn_checkpoint_out = "checkpoint.bin"; params.fn_model_out = "ggml-checkpoint-f32.bin"; - params.pattern_fn_it = "ITERATION"; - params.fn_latest = "LATEST"; - - params.save_every = 10; - - params.seed = -1; params.n_ctx = 128; params.n_embd = 256; @@ -780,10 +730,7 @@ struct train_params get_default_train_params() { params.n_layer = 16; params.n_ff = 768; - params.n_threads = 6; params.n_examples = 1; - params.n_batch = 8; - params.n_gradient_accumulation = 1; params.f_norm_rms_eps = 1e-5f; params.rope_freq_base = 10000.0f; @@ -791,60 +738,22 @@ struct train_params get_default_train_params() { params.print_info_interval = 1; - params.use_flash = true; - params.use_checkpointing = true; params.use_alloc = true; - params.sample_start = ""; - params.include_sample_start = false; - params.escape = false; - params.overlapping_samples = false; - params.fill_with_next_samples = false; - params.separate_with_eos = false; - params.separate_with_bos = true; - params.force_reshuffle = false; - - params.opt_past = 0; - params.opt_delta = 1e-5f; - params.opt_max_no_improvement = 0; - - params.warmup = 100; - params.cos_decay_steps = 1000; - params.cos_decay_restart = 1.1f; - params.cos_decay_min = 0.1f; - params.enable_restart = false; - - params.adam_n_iter = 256; - params.adam_alpha = 1e-3f; - params.adam_min_alpha = 0; - params.adam_decay = 1e-1f; - params.adam_decay_min_ndim = 2; - params.adam_beta1 = 0.9f; - params.adam_beta2 = 0.999f; - params.adam_gclip = 1.0f; - params.adam_eps_f = 0.0f; - params.mem_model_gb = 2; params.mem_compute_gb = 24; params.mem_compute0_gb = 8; return params; } -static void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) { +static void train_print_usage(int argc, char ** argv, const struct train_params * params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model); - fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); - fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); - fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out); - fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); - fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); - fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. 
(default '%d')\n", params->save_every); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); - fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd); fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff); fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head); @@ -852,49 +761,15 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); - fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples); - fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); - fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval); - fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str()); - fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); - fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); - fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); - fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? 
" (default)" : ""); - fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); - fprintf(stderr, " --no-flash Don't use flash attention \n"); - fprintf(stderr, " --use-flash Use flash attention (default)\n"); - fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); - fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n"); fprintf(stderr, " --no-alloc Don't use allocator\n"); fprintf(stderr, " --use-alloc Use allocator (default)\n"); - fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); - fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); - fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); - fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min); - fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past); - fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta); - fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement); - fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); - fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); - fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); - fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha); - fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); - fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim); - fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1); - fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); - fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip); fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb); fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb); fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. 
(default %d)\n", params->mem_compute0_gb);
-    fprintf(stderr, "\n");
+
+    print_common_train_usage(argc, argv, &params->common);
 }
 
 static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
@@ -909,66 +784,25 @@
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
 
-        if (arg == "--vocab-model") {
+        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
+            if (invalid_param) {
+                break;
+            } else if (params->common.print_usage) {
+                train_print_usage(argc, argv, &default_params);
+                exit(0);
+            }
+        } else if (arg == "--vocab-model") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_vocab_model = argv[i];
-        } else if (arg == "--train-data") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_train_data = argv[i];
-        } else if (arg == "--checkpoint-in") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_in = argv[i];
-        } else if (arg == "--checkpoint-out") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_checkpoint_out = argv[i];
         } else if (arg == "--model-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->fn_model_out = argv[i];
-        } else if (arg == "--pattern-fn-it") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->pattern_fn_it = argv[i];
-        } else if (arg == "--fn-latest") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->fn_latest = argv[i];
-        } else if (arg == "--save-every") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->save_every = std::stoi(argv[i]);
-        } else if (arg == "-s" || arg == "--seed") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->seed = std::stoi(argv[i]);
-        } else if (arg == "-c" || arg == "--ctx") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_ctx = std::stoi(argv[i]);
         } else if (arg == "--embd") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1011,24 +845,6 @@
                 break;
             }
             params->rope_freq_scale = std::stof(argv[i]);
-        } else if (arg == "-t" || arg == "--threads") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_threads = std::stoi(argv[i]);
-        } else if (arg == "-b" || arg == "--batch") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_batch = std::stoi(argv[i]);
-        } else if (arg == "--grad-acc") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
         } else if (arg == "-n" || arg == "--examples") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1041,142 +857,10 @@
                 break;
             }
             params->print_info_interval = std::stoi(argv[i]);
-        } else if (arg == "--sample-start") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->sample_start = std::string(argv[i]);
-        } else if (arg == "--escape") {
-            params->escape = true;
-        } else if (arg == "--include-sample-start") {
-            params->include_sample_start = true;
-        } else if (arg == "--overlapping-samples") {
-            params->overlapping_samples = true;
-        } else if (arg == "--fill-with-next-samples") {
-            params->fill_with_next_samples = true;
-        } else if (arg == "--separate-with-eos") {
-            params->separate_with_eos = true;
-        } else if (arg == "--separate-with-bos") {
-            params->separate_with_bos = true;
-        } else if (arg == 
"--no-separate-with-eos") { - params->separate_with_eos = false; - } else if (arg == "--no-separate-with-bos") { - params->separate_with_bos = false; - } else if (arg == "--force-reshuffle") { - params->force_reshuffle = true; - } else if (arg == "--no-flash") { - params->use_flash = false; - } else if (arg == "--use-flash") { - params->use_flash = true; - } else if (arg == "--no-checkpointing") { - params->use_checkpointing = false; - } else if (arg == "--use-checkpointing") { - params->use_checkpointing = true; } else if (arg == "--no-alloc") { params->use_alloc = false; } else if (arg == "--use-alloc") { params->use_alloc = true; - } else if (arg == "--warmup") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->warmup = std::stoi(argv[i]); - } else if (arg == "--cos-decay-steps") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->cos_decay_steps = std::stof(argv[i]); - } else if (arg == "--cos-decay-restart") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->cos_decay_restart = std::stof(argv[i]); - } else if (arg == "--cos-decay-min") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->cos_decay_min = std::stof(argv[i]); - } else if (arg == "--enable-restart") { - params->enable_restart = true; - } else if (arg == "--disable-restart") { - params->enable_restart = false; - } else if (arg == "--opt-past") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->opt_past = std::stoi(argv[i]); - } else if (arg == "--opt-delta") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->opt_delta = std::stof(argv[i]); - } else if (arg == "--opt-max-no-improvement") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->opt_max_no_improvement = std::stoi(argv[i]); - } else if (arg == "--adam-epsf") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_eps_f = std::stof(argv[i]); - } else if (arg == "--adam-iter") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_n_iter = std::stoi(argv[i]); - } else if (arg == "--adam-alpha") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_alpha = std::stof(argv[i]); - } else if (arg == "--adam-min-alpha") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_min_alpha = std::stof(argv[i]); - } else if (arg == "--adam-decay") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_decay = std::stof(argv[i]); - } else if (arg == "--adam-decay-min-ndim") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_decay_min_ndim = std::stoi(argv[i]); - } else if (arg == "--adam-beta1") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_beta1 = std::stof(argv[i]); - } else if (arg == "--adam-beta2") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_beta2 = std::stof(argv[i]); - } else if (arg == "--adam-gclip") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->adam_gclip = std::stof(argv[i]); } else if (arg == "--mem-model") { if (++i >= argc) { invalid_param = true; @@ -1195,9 +879,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par break; } params->mem_compute0_gb = std::stoi(argv[i]); - } else if (arg == "-h" || arg == "--help") { - train_print_usage(argc, argv, &default_params); - exit(0); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); train_print_usage(argc, argv, &default_params); @@ -1209,9 
@@ -1209,9 +890,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         train_print_usage(argc, argv, &default_params);
         exit(1);
     }
 
-    if (params->escape) {
-        process_escapes(params->sample_start);
-    }
+    finish_processing_train_args(&params->common);
     return true;
 }
 
@@ -1241,32 +920,32 @@ static void save_train_files(void * vdata, struct train_state * train) {
 }
 
 struct opt_callback_data {
-    struct train_params * params;
-    struct train_state * train;
-    save_train_files_callback save_cb;
-    void * save_data;
-    struct llama_context * lctx;
-    int last_save_iter;
-    llama_token * tokens_data;
-    size_t tokens_size;
-    size_t * samples_begin;
-    size_t * samples_size;
-    size_t * shuffled_samples_begin;
-    size_t * shuffled_samples_size;
-    size_t samples_count;
-    struct ggml_tensor * tokens_input;
-    struct ggml_tensor * target_logits;
-    struct ggml_tensor * target_probs;
-    int first_iter;
-    int64_t last_time;
-    double millis_per_iter;
+    struct train_params_common * params;
+    struct train_state         * train;
+    save_train_files_callback    save_cb;
+    void                       * save_data;
+    struct llama_context       * lctx;
+    int                          last_save_iter;
+    llama_token                * tokens_data;
+    size_t                       tokens_size;
+    size_t                     * samples_begin;
+    size_t                     * samples_size;
+    size_t                     * shuffled_samples_begin;
+    size_t                     * shuffled_samples_size;
+    size_t                       samples_count;
+    struct ggml_tensor         * tokens_input;
+    struct ggml_tensor         * target_logits;
+    struct ggml_tensor         * target_probs;
+    int                          first_iter;
+    int64_t                      last_time;
+    double                       millis_per_iter;
 };
 
 static void opt_callback(void * vdata, int accum_step, float * sched) {
-    struct opt_callback_data * data = (struct opt_callback_data *) vdata;
-    struct train_params * params = data->params;
-    struct train_state * train = data->train;
-    struct ggml_opt_context * opt = train->opt;
+    struct opt_callback_data   * data   = (struct opt_callback_data *) vdata;
+    struct train_params_common * params = data->params;
+    struct train_state         * train  = data->train;
+    struct ggml_opt_context    * opt    = train->opt;
 
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;
@@ -1385,11 +1064,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
+    if (params.common.seed == LLAMA_DEFAULT_SEED) {
+        params.common.seed = time(NULL);
     }
-    printf("%s: seed: %u\n", __func__, params.seed);
-    srand(params.seed);
+    printf("%s: seed: %u\n", __func__, params.common.seed);
+    srand(params.common.seed);
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
@@ -1399,7 +1078,7 @@ int main(int argc, char ** argv) {
 
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx = params.n_ctx;
+    model.hparams.n_ctx = params.common.n_ctx;
     model.hparams.n_embd = params.n_embd;
     model.hparams.n_head = params.n_head;
     model.hparams.n_layer = params.n_layer;
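
finish_processing_train_args, introduced in the first hunk above, absorbs the post-parse fixup that used to live inline. Judging only by the code it replaces here, a minimal equivalent would look like the following; the real helper in common/train.cpp may validate more than escape expansion (process_escapes is the existing helper from llama.cpp's common code):

    void finish_processing_train_args(struct train_params_common * params) {
        if (params->escape) {
            // expand \n, \r, \t, \', \", \\ typed on the command line for --sample-start
            process_escapes(params->sample_start);
        }
    }
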
@@ -1421,34 +1100,34 @@ int main(int argc, char ** argv) {
 
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
-    int n_batch = params.n_batch;
+    int n_batch = params.common.n_batch;
 
-    struct train_state * train = init_train_state(params.seed);
+    struct train_state * train = init_train_state(params.common.seed);
     struct ggml_opt_context * opt = train->opt;
 
     struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = params.n_threads;
-    opt_params_adam.past = params.opt_past;
-    opt_params_adam.delta = params.opt_delta;
-    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
-    opt_params_adam.n_gradient_accumulation = params.n_gradient_accumulation;
-    opt_params_adam.adam.n_iter = params.adam_n_iter;
+    opt_params_adam.n_threads = params.common.n_threads;
+    opt_params_adam.past = params.common.opt_past;
+    opt_params_adam.delta = params.common.opt_delta;
+    opt_params_adam.max_no_improvement = params.common.opt_max_no_improvement;
+    opt_params_adam.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt_params_adam.adam.n_iter = params.common.adam_n_iter;
     opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = params.adam_alpha;
-    opt_params_adam.adam.decay = params.adam_decay;
-    opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
-    opt_params_adam.adam.beta1 = params.adam_beta1;
-    opt_params_adam.adam.beta2 = params.adam_beta2;
-    opt_params_adam.adam.gclip = params.adam_gclip;
-    opt_params_adam.adam.eps_f = params.adam_eps_f;
+    opt_params_adam.adam.alpha = params.common.adam_alpha;
+    opt_params_adam.adam.decay = params.common.adam_decay;
+    opt_params_adam.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+    opt_params_adam.adam.beta1 = params.common.adam_beta1;
+    opt_params_adam.adam.beta2 = params.common.adam_beta2;
+    opt_params_adam.adam.gclip = params.common.adam_gclip;
+    opt_params_adam.adam.eps_f = params.common.adam_eps_f;
 
     opt->ctx = model.ctx;
     opt->params = opt_params_adam;
 
     printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train);
+    bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
     if (!existed) {
         init_model(&model);
     }
@@ -1461,7 +1140,7 @@ int main(int argc, char ** argv) {
 
     bool from_scratch = !existed;
     if (from_scratch) {
-        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
     }
 
     printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
@@ -1485,10 +1164,10 @@ int main(int argc, char ** argv) {
     std::vector<size_t> train_samples_size;
     printf("%s: tokenize training data\n", __func__);
     tokenize_file(lctx,
-            params.fn_train_data,
-            params.sample_start,
-            params.include_sample_start,
-            params.overlapping_samples,
+            params.common.fn_train_data,
+            params.common.sample_start,
+            params.common.include_sample_start,
+            params.common.overlapping_samples,
             n_tokens,
             train_tokens,
             train_samples_begin,
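
For readers new to tokenize_file: it fills one flat token buffer plus parallel (begin, size) arrays describing where each training sample starts and how long it is; the shuffle machinery below permutes those indices rather than the tokens themselves. A self-contained sketch of walking that representation (llama_token is typedef'd to int32_t in llama.h; the helper name here is made up for illustration):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    typedef int32_t llama_token;

    static size_t count_sample_tokens(
            const std::vector<llama_token> & tokens,
            const std::vector<size_t>      & samples_begin,
            const std::vector<size_t>      & samples_size) {
        size_t total = 0;
        for (size_t i = 0; i < samples_begin.size(); ++i) {
            // sample i occupies tokens[samples_begin[i] .. samples_begin[i] + samples_size[i])
            total += samples_size[i];
        }
        return total;
    }
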
@@ -1497,16 +1176,16 @@ int main(int argc, char ** argv) {
     printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
 
-    size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+    size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
     const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
     if (changed_train_data) {
         printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
     }
-    if (params.force_reshuffle) {
+    if (params.common.force_reshuffle) {
         printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
     }
 
-    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
-        train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
+    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+        train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
         train->shuffle_sample_count = train_samples_size.size();
         train->shuffle_next_sample = 0;
         train->shuffle_samples_hash = shuffle_samples_hash;
@@ -1525,15 +1204,15 @@ int main(int argc, char ** argv) {
     printf("%s: begin training\n", __func__);
 
     save_train_files_data save_data;
-    save_data.fn_checkpoint_out = params.fn_checkpoint_out;
+    save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
     save_data.fn_model_out = params.fn_model_out;
     save_data.fn_vocab_model = params.fn_vocab_model;
-    save_data.pattern_fn_it = params.pattern_fn_it;
-    save_data.fn_latest = params.fn_latest;
+    save_data.pattern_fn_it = params.common.pattern_fn_it;
+    save_data.fn_latest = params.common.fn_latest;
     save_data.model = &model;
 
     struct opt_callback_data opt_cb_data;
-    opt_cb_data.params = &params;
+    opt_cb_data.params = &params.common;
     opt_cb_data.train = train;
     opt_cb_data.save_cb = &save_train_files;
     opt_cb_data.save_data = &save_data;
@@ -1587,7 +1266,7 @@ int main(int argc, char ** argv) {
 
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
     struct ggml_cgraph * gb = ggml_new_graph(ctx0);
-    struct ggml_cgraph * gb_tmp = params.use_checkpointing
+    struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
         ? ggml_new_graph(ctx0)
         : NULL;
 
@@ -1601,21 +1280,21 @@ int main(int argc, char ** argv) {
         gf, gb, gb_tmp,
         &logits, tokens_input, target_probs,
         n_tokens, n_batch,
-        params.use_flash,
-        params.use_checkpointing
+        params.common.use_flash,
+        params.common.use_checkpointing
     );
 
     size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
     opt->params.adam.sched = learning_schedule(
         opt->iter,
-        params.warmup,
-        params.cos_decay_steps,
-        params.adam_alpha,
-        params.adam_min_alpha,
-        params.cos_decay_min,
-        params.cos_decay_restart,
-        params.enable_restart);
+        params.common.warmup,
+        params.common.cos_decay_steps,
+        params.common.adam_alpha,
+        params.common.adam_min_alpha,
+        params.common.cos_decay_min,
+        params.common.cos_decay_restart,
+        params.common.enable_restart);
 
     printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
@@ -1623,7 +1302,7 @@ int main(int argc, char ** argv) {
 
     size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
-    int n_iter = params.adam_n_iter;
+    int n_iter = params.common.adam_n_iter;
     train->train_its = opt->iter;
     train->train_samples += n_batch * n_iter;
     train->train_tokens += n_batch * n_tokens * n_iter;
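
The learning_schedule call above computes the factor stored in opt->params.adam.sched: linear warmup up to the full learning rate, then cosine decay toward cos_decay_min, with optional restarts that stretch the decay window by cos_decay_restart each cycle. A hedged sketch of the warmup-plus-cosine part, ignoring restarts (the real function lives in common/train.cpp and may differ in detail):

    #include <cmath>

    // Returns a multiplier for adam_alpha in [min_alpha/alpha, 1].
    static float sketch_learning_schedule(
            int step, int warmup, int cos_decay_steps,
            float alpha, float min_alpha, float cos_decay_min) {
        const float floor = min_alpha / alpha;
        if (step < warmup) {
            // linear warmup from the floor up to 1.0
            return floor + (1.0f - floor) * (float) step / (float) warmup;
        }
        // cosine decay from 1.0 down to cos_decay_min over cos_decay_steps
        float t = (float) (step - warmup) / (float) cos_decay_steps;
        if (t > 1.0f) t = 1.0f;
        return cos_decay_min
            + (1.0f - cos_decay_min) * 0.5f * (1.0f + cosf(3.14159265f * t));
    }
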