specify default lora rank with '--lora-r N'

'--lora-r N' specifies the default rank for all tensors.
'--rank-wq N', etc., override this default rank for specific tensor types.
xaedes 2023-09-06 20:11:22 +02:00
parent 8c2d7e37f9
commit c08fcf5947
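The change has two halves: the argument parser records, per tensor type, whether a '--rank-*' flag was given (the new custom_n_rank_* fields), and main() then resolves each tensor's rank to either that explicit value or a fallback: 1 for norm tensors, the global '--lora-r' value for everything else. A minimal standalone sketch of that resolution rule (hypothetical names, not code from this commit):

#include <cstdint>
#include <cstdio>

// Hypothetical pair of fields mirroring params.n_rank_* / params.custom_n_rank_*.
struct rank_opt {
    uint32_t value;
    bool     custom;
};

// Resolution rule introduced by the commit: an explicit --rank-* flag wins;
// otherwise norm tensors fall back to rank 1 and all other tensors to --lora-r.
static uint32_t resolve_rank(const rank_opt & opt, uint32_t lora_r, bool is_norm) {
    return opt.custom ? opt.value : (is_norm ? 1u : lora_r);
}

int main() {
    const uint32_t lora_r = 4;               // as if --lora-r 4 was given
    rank_opt wq       = { 0, false };        // no --rank-wq flag
    rank_opt out_norm = { 8, true  };        // as if --rank-out-norm 8 was given
    printf("wq rank:       %u\n", resolve_rank(wq,       lora_r, false)); // 4
    printf("out_norm rank: %u\n", resolve_rank(out_norm, lora_r, true));  // 8
    return 0;
}

Under this rule, an invocation such as '--lora-r 8 --rank-tok-embd 16' would give every tensor rank 8, except token embeddings at 16 and the norm tensors at 1.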

@@ -1760,6 +1760,19 @@ struct train_params {
     uint32_t n_rank_norm;
     uint32_t n_rank_output;
+    bool custom_n_rank_attention_norm;
+    bool custom_n_rank_wq;
+    bool custom_n_rank_wk;
+    bool custom_n_rank_wv;
+    bool custom_n_rank_wo;
+    bool custom_n_rank_ffn_norm;
+    bool custom_n_rank_w1;
+    bool custom_n_rank_w2;
+    bool custom_n_rank_w3;
+    bool custom_n_rank_tok_embeddings;
+    bool custom_n_rank_norm;
+    bool custom_n_rank_output;
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
@@ -1835,6 +1848,19 @@ struct train_params get_default_train_params() {
     params.n_rank_norm = 1;
     params.n_rank_output = 4;
+    params.custom_n_rank_attention_norm = false;
+    params.custom_n_rank_wq = false;
+    params.custom_n_rank_wk = false;
+    params.custom_n_rank_wv = false;
+    params.custom_n_rank_wo = false;
+    params.custom_n_rank_ffn_norm = false;
+    params.custom_n_rank_w1 = false;
+    params.custom_n_rank_w2 = false;
+    params.custom_n_rank_w3 = false;
+    params.custom_n_rank_tok_embeddings = false;
+    params.custom_n_rank_norm = false;
+    params.custom_n_rank_output = false;
     params.samples_start_after_nl = false;
     params.use_adam = true;
     params.use_flash = true;
@@ -1887,19 +1913,19 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
-    fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
-    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
-    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor (default %d)\n", params->n_rank_ffn_norm);
-    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor (default %d)\n", params->n_rank_norm);
-    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor (default %d)\n", params->n_rank_tok_embeddings);
-    fprintf(stderr, " --rank-out N LORA rank for output tensor (default %d)\n", params->n_rank_output);
-    fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq);
-    fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk);
-    fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv);
-    fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo);
-    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1);
-    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
-    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
+    fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
+    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
+    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
     fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
     fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
@@ -2063,72 +2089,84 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_rank_attention_norm = std::stoi(argv[i]);
+            params->custom_n_rank_attention_norm = true;
         } else if (arg == "--rank-ffn-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_ffn_norm = std::stoi(argv[i]);
+            params->custom_n_rank_ffn_norm = true;
         } else if (arg == "--rank-out-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_norm = std::stoi(argv[i]);
+            params->custom_n_rank_norm = true;
         } else if (arg == "--rank-tok-embd") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_tok_embeddings = std::stoi(argv[i]);
+            params->custom_n_rank_tok_embeddings = true;
         } else if (arg == "--rank-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_output = std::stoi(argv[i]);
+            params->custom_n_rank_output = true;
         } else if (arg == "--rank-wq") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wq = std::stoi(argv[i]);
+            params->custom_n_rank_wq = true;
         } else if (arg == "--rank-wk") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wk = std::stoi(argv[i]);
+            params->custom_n_rank_wk = true;
         } else if (arg == "--rank-wv") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wv = std::stoi(argv[i]);
+            params->custom_n_rank_wv = true;
         } else if (arg == "--rank-wo") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wo = std::stoi(argv[i]);
+            params->custom_n_rank_wo = true;
         } else if (arg == "--rank-w1") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w1 = std::stoi(argv[i]);
+            params->custom_n_rank_w1 = true;
         } else if (arg == "--rank-w2") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w2 = std::stoi(argv[i]);
+            params->custom_n_rank_w2 = true;
         } else if (arg == "--rank-w3") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w3 = std::stoi(argv[i]);
+            params->custom_n_rank_w3 = true;
         } else if (arg == "--samples-after-nl") {
             params->samples_start_after_nl = true;
         } else if (arg == "--use-lbfgs") {
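Each of the new '--rank-*' branches above follows the same pattern: consume the next argv entry, parse it with std::stoi, and set both the rank value and its custom_n_rank_* flag. A hedged sketch of that pattern as a reusable helper (hypothetical, not part of the commit):

#include <cstdint>
#include <string>

// Hypothetical helper mirroring the repeated parsing pattern above:
// consume the next argv entry as an integer rank and remember that it was set explicitly.
static bool parse_rank_arg(int argc, char ** argv, int & i,
                           uint32_t & rank, bool & custom, bool & invalid_param) {
    if (++i >= argc) {           // flag given without a value
        invalid_param = true;
        return false;
    }
    rank   = std::stoi(argv[i]); // same conversion the new branches use
    custom = true;               // equivalent of params->custom_n_rank_* = true
    return true;
}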
@@ -2494,18 +2532,30 @@ int main(int argc, char ** argv) {
     }
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
-    lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
-    lora.hparams.n_rank_wq = params.n_rank_wq;
-    lora.hparams.n_rank_wk = params.n_rank_wk;
-    lora.hparams.n_rank_wv = params.n_rank_wv;
-    lora.hparams.n_rank_wo = params.n_rank_wo;
-    lora.hparams.n_rank_ffn_norm = params.n_rank_ffn_norm;
-    lora.hparams.n_rank_w1 = params.n_rank_w1;
-    lora.hparams.n_rank_w2 = params.n_rank_w2;
-    lora.hparams.n_rank_w3 = params.n_rank_w3;
-    lora.hparams.n_rank_tok_embeddings = params.n_rank_tok_embeddings;
-    lora.hparams.n_rank_norm = params.n_rank_norm;
-    lora.hparams.n_rank_output = params.n_rank_output;
+    int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
+    int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
+    int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
+    int n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
+    int n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
+    int n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
+    int n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
+    int n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
+    int n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
+    int n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
+    int n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
+    int n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
+    lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
+    lora.hparams.n_rank_wq = n_rank_wq;
+    lora.hparams.n_rank_wk = n_rank_wk;
+    lora.hparams.n_rank_wv = n_rank_wv;
+    lora.hparams.n_rank_wo = n_rank_wo;
+    lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
+    lora.hparams.n_rank_w1 = n_rank_w1;
+    lora.hparams.n_rank_w2 = n_rank_w2;
+    lora.hparams.n_rank_w3 = n_rank_w3;
+    lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
+    lora.hparams.n_rank_norm = n_rank_norm;
+    lora.hparams.n_rank_output = n_rank_output;
     // set opt params from command line
     if (params.use_adam) {
@@ -2550,30 +2600,30 @@ int main(int argc, char ** argv) {
     }
     const bool opt_param_count_changed = (
-        (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)
-        || (lora.hparams.n_rank_wq != params.n_rank_wq)
-        || (lora.hparams.n_rank_wk != params.n_rank_wk)
-        || (lora.hparams.n_rank_wv != params.n_rank_wv)
-        || (lora.hparams.n_rank_wo != params.n_rank_wo)
-        || (lora.hparams.n_rank_ffn_norm != params.n_rank_ffn_norm)
-        || (lora.hparams.n_rank_w1 != params.n_rank_w1)
-        || (lora.hparams.n_rank_w2 != params.n_rank_w2)
-        || (lora.hparams.n_rank_w3 != params.n_rank_w3)
-        || (lora.hparams.n_rank_tok_embeddings != params.n_rank_tok_embeddings)
-        || (lora.hparams.n_rank_norm != params.n_rank_norm)
-        || (lora.hparams.n_rank_output != params.n_rank_output)
+        (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
+        || (lora.hparams.n_rank_wq != n_rank_wq)
+        || (lora.hparams.n_rank_wk != n_rank_wk)
+        || (lora.hparams.n_rank_wv != n_rank_wv)
+        || (lora.hparams.n_rank_wo != n_rank_wo)
+        || (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
+        || (lora.hparams.n_rank_w1 != n_rank_w1)
+        || (lora.hparams.n_rank_w2 != n_rank_w2)
+        || (lora.hparams.n_rank_w3 != n_rank_w3)
+        || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
+        || (lora.hparams.n_rank_norm != n_rank_norm)
+        || (lora.hparams.n_rank_output != n_rank_output)
     );
     const bool opt_past_changed = opt->params.past != params.opt_past;
-    GGML_ASSERT(opt_param_count_changed == false);
-    GGML_ASSERT(opt_past_changed == false);
     if (opt_param_count_changed) {
+        print_lora_params(&lora.hparams);
+        GGML_ASSERT(!"Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
         // need to discard previous optimizer gradient statistics and opt_init with new shapes
         // TODO
     }
     if (opt_past_changed) {
+        GGML_ASSERT(!"Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
         // need to discard previous optimizer past function value statistics and opt_init with new shapes
        // TODO
     }
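For reference, the updated usage text keeps the relation that the resulting LoRA scaling is alpha/r, computed from '--lora-alpha' and the global '--lora-r'; a minimal worked example with hypothetical values (not the program's defaults):

#include <cstdio>

int main() {
    const int lora_alpha = 4;   // hypothetical --lora-alpha 4
    const int lora_r     = 8;   // hypothetical --lora-r 8
    const float scaling  = (float) lora_alpha / (float) lora_r;
    printf("LoRA scaling = %d/%d = %g\n", lora_alpha, lora_r, scaling); // prints 0.5
    return 0;
}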