diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index a2ad711fb..632392970 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1760,6 +1760,19 @@ struct train_params {
     uint32_t n_rank_norm;
     uint32_t n_rank_output;
 
+    bool custom_n_rank_attention_norm;
+    bool custom_n_rank_wq;
+    bool custom_n_rank_wk;
+    bool custom_n_rank_wv;
+    bool custom_n_rank_wo;
+    bool custom_n_rank_ffn_norm;
+    bool custom_n_rank_w1;
+    bool custom_n_rank_w2;
+    bool custom_n_rank_w3;
+    bool custom_n_rank_tok_embeddings;
+    bool custom_n_rank_norm;
+    bool custom_n_rank_output;
+
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
@@ -1835,6 +1848,19 @@ struct train_params get_default_train_params() {
     params.n_rank_norm = 1;
     params.n_rank_output = 4;
 
+    params.custom_n_rank_attention_norm = false;
+    params.custom_n_rank_wq = false;
+    params.custom_n_rank_wk = false;
+    params.custom_n_rank_wv = false;
+    params.custom_n_rank_wo = false;
+    params.custom_n_rank_ffn_norm = false;
+    params.custom_n_rank_w1 = false;
+    params.custom_n_rank_w2 = false;
+    params.custom_n_rank_w3 = false;
+    params.custom_n_rank_tok_embeddings = false;
+    params.custom_n_rank_norm = false;
+    params.custom_n_rank_output = false;
+
     params.samples_start_after_nl = false;
     params.use_adam = true;
     params.use_flash = true;
@@ -1887,19 +1913,19 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
-    fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
-    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
-    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor (default %d)\n", params->n_rank_ffn_norm);
-    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor (default %d)\n", params->n_rank_norm);
-    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor (default %d)\n", params->n_rank_tok_embeddings);
-    fprintf(stderr, " --rank-out N LORA rank for output tensor (default %d)\n", params->n_rank_output);
-    fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq);
-    fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk);
-    fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv);
-    fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo);
-    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1);
-    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
-    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
+    fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
+    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
     fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
     fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
@@ -2063,72 +2089,84 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_rank_attention_norm = std::stoi(argv[i]);
+            params->custom_n_rank_attention_norm = true;
         } else if (arg == "--rank-ffn-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_ffn_norm = std::stoi(argv[i]);
+            params->custom_n_rank_ffn_norm = true;
         } else if (arg == "--rank-out-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_norm = std::stoi(argv[i]);
+            params->custom_n_rank_norm = true;
         } else if (arg == "--rank-tok-embd") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_tok_embeddings = std::stoi(argv[i]);
+            params->custom_n_rank_tok_embeddings = true;
        } else if (arg == "--rank-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_output = std::stoi(argv[i]);
+            params->custom_n_rank_output = true;
         } else if (arg == "--rank-wq") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wq = std::stoi(argv[i]);
+            params->custom_n_rank_wq = true;
         } else if (arg == "--rank-wk") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wk = std::stoi(argv[i]);
+            params->custom_n_rank_wk = true;
         } else if (arg == "--rank-wv") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wv = std::stoi(argv[i]);
+            params->custom_n_rank_wv = true;
         } else if (arg == "--rank-wo") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wo = std::stoi(argv[i]);
+            params->custom_n_rank_wo = true;
         } else if (arg == "--rank-w1") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w1 = std::stoi(argv[i]);
+            params->custom_n_rank_w1 = true;
         } else if (arg == "--rank-w2") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w2 = std::stoi(argv[i]);
+            params->custom_n_rank_w2 = true;
         } else if (arg == "--rank-w3") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w3 = std::stoi(argv[i]);
+            params->custom_n_rank_w3 = true;
         } else if (arg == "--samples-after-nl") {
             params->samples_start_after_nl = true;
         } else if (arg == "--use-lbfgs") {
@@ -2494,18 +2532,30 @@ int main(int argc, char ** argv) {
     }
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
-    lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
-    lora.hparams.n_rank_wq = params.n_rank_wq;
-    lora.hparams.n_rank_wk = params.n_rank_wk;
-    lora.hparams.n_rank_wv = params.n_rank_wv;
-    lora.hparams.n_rank_wo = params.n_rank_wo;
-    lora.hparams.n_rank_ffn_norm = params.n_rank_ffn_norm;
-    lora.hparams.n_rank_w1 = params.n_rank_w1;
-    lora.hparams.n_rank_w2 = params.n_rank_w2;
-    lora.hparams.n_rank_w3 = params.n_rank_w3;
-    lora.hparams.n_rank_tok_embeddings = params.n_rank_tok_embeddings;
-    lora.hparams.n_rank_norm = params.n_rank_norm;
-    lora.hparams.n_rank_output = params.n_rank_output;
+    int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
+    int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
+    int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
+    int n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
+    int n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
+    int n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
+    int n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
+    int n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
+    int n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
+    int n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
+    int n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
+    int n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
+    lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
+    lora.hparams.n_rank_wq = n_rank_wq;
+    lora.hparams.n_rank_wk = n_rank_wk;
+    lora.hparams.n_rank_wv = n_rank_wv;
+    lora.hparams.n_rank_wo = n_rank_wo;
+    lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
+    lora.hparams.n_rank_w1 = n_rank_w1;
+    lora.hparams.n_rank_w2 = n_rank_w2;
+    lora.hparams.n_rank_w3 = n_rank_w3;
+    lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
+    lora.hparams.n_rank_norm = n_rank_norm;
+    lora.hparams.n_rank_output = n_rank_output;
 
     // set opt params from command line
     if (params.use_adam) {
@@ -2550,30 +2600,30 @@ int main(int argc, char ** argv) {
         }
 
         const bool opt_param_count_changed = (
-            (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)
-         || (lora.hparams.n_rank_wq != params.n_rank_wq)
-         || (lora.hparams.n_rank_wk != params.n_rank_wk)
-         || (lora.hparams.n_rank_wv != params.n_rank_wv)
-         || (lora.hparams.n_rank_wo != params.n_rank_wo)
-         || (lora.hparams.n_rank_ffn_norm != params.n_rank_ffn_norm)
-         || (lora.hparams.n_rank_w1 != params.n_rank_w1)
-         || (lora.hparams.n_rank_w2 != params.n_rank_w2)
-         || (lora.hparams.n_rank_w3 != params.n_rank_w3)
-         || (lora.hparams.n_rank_tok_embeddings != params.n_rank_tok_embeddings)
-         || (lora.hparams.n_rank_norm != params.n_rank_norm)
-         || (lora.hparams.n_rank_output != params.n_rank_output)
+            (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
+         || (lora.hparams.n_rank_wq != n_rank_wq)
+         || (lora.hparams.n_rank_wk != n_rank_wk)
+         || (lora.hparams.n_rank_wv != n_rank_wv)
+         || (lora.hparams.n_rank_wo != n_rank_wo)
+         || (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
+         || (lora.hparams.n_rank_w1 != n_rank_w1)
+         || (lora.hparams.n_rank_w2 != n_rank_w2)
+         || (lora.hparams.n_rank_w3 != n_rank_w3)
+         || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
+         || (lora.hparams.n_rank_norm != n_rank_norm)
+         || (lora.hparams.n_rank_output != n_rank_output)
         );
 
         const bool opt_past_changed = opt->params.past != params.opt_past;
 
-        GGML_ASSERT(opt_param_count_changed == false);
-        GGML_ASSERT(opt_past_changed == false);
-
         if (opt_param_count_changed) {
+            print_lora_params(&lora.hparams);
+            GGML_ASSERT(!"Provided rank differs from checkpoint file. To use a different rank, start finetune from scratch with an empty input checkpoint, e.g. --checkpoint-in ''. Aborting.");
             // need to discard previous optimizer gradient statistics and opt_init with new shapes
             // TODO
         }
         if (opt_past_changed) {
+            GGML_ASSERT(!"Optimizer parameter '--opt-past N' differs from checkpoint file. To use a different value, finetune from scratch with an empty input checkpoint, e.g. --checkpoint-in ''. Aborting.");
             // need to discard previous optimizer past function value statistics and opt_init with new shapes
             // TODO
         }
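
Not part of the patch: a minimal, standalone C++ sketch (hypothetical names rank_override and resolve_rank) of the override-or-default selection the patch performs in main(). A tensor's LORA rank falls back to --lora-r (or to 1 for norm tensors) unless its --rank-* flag was given, in which case the corresponding custom_n_rank_* flag routes the explicit value through.

    #include <cstdint>
    #include <cstdio>

    // Mirrors the paired n_rank_* / custom_n_rank_* fields added to train_params.
    struct rank_override {
        uint32_t value;  // value parsed from a --rank-* flag (ignored when !custom)
        bool     custom; // true only when that flag appeared on the command line
    };

    // Same rule as the patch: take the explicit override if present, else the fallback.
    static uint32_t resolve_rank(const rank_override & ov, uint32_t fallback) {
        return ov.custom ? ov.value : fallback;
    }

    int main() {
        const uint32_t lora_r = 4;                 // --lora-r 4
        const rank_override wq   = {  0, false };  // --rank-wq not given
        const rank_override out  = { 16, true  };  // --rank-out 16
        const rank_override norm = {  0, false };  // --rank-out-norm not given

        std::printf("wq   rank: %u\n", (unsigned) resolve_rank(wq,   lora_r)); // 4
        std::printf("out  rank: %u\n", (unsigned) resolve_rank(out,  lora_r)); // 16
        std::printf("norm rank: %u\n", (unsigned) resolve_rank(norm, 1));      // 1: norm tensors default to rank 1
        return 0;
    }

With these rules, an invocation such as --lora-r 4 --rank-out 16 trains most LORA tensors at rank 4 and the output tensor at rank 16, while norm tensors stay at rank 1 unless their flags are passed explicitly.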