specify default lora rank with '--lora-r N'

'--lora-r N' specifies the default rank for all tensors.
'--rank-wq N', etc. override this default for specific tensor types.
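
For illustration, a hypothetical invocation (the rank flags are the ones added in this commit; the binary name, '--model-base', '--train-data', and the file paths are assumed placeholders):

    # rank 8 for all LORA tensors, overridden to 16 for the wq tensors
    ./bin/finetune --model-base base-model.gguf --train-data train.txt \
        --lora-r 8 --rank-wq 16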
xaedes 2023-09-06 20:11:22 +02:00
parent 8c2d7e37f9
commit c08fcf5947


@@ -1760,6 +1760,19 @@ struct train_params {
uint32_t n_rank_norm;
uint32_t n_rank_output;
bool custom_n_rank_attention_norm;
bool custom_n_rank_wq;
bool custom_n_rank_wk;
bool custom_n_rank_wv;
bool custom_n_rank_wo;
bool custom_n_rank_ffn_norm;
bool custom_n_rank_w1;
bool custom_n_rank_w2;
bool custom_n_rank_w3;
bool custom_n_rank_tok_embeddings;
bool custom_n_rank_norm;
bool custom_n_rank_output;
bool samples_start_after_nl;
bool use_adam;
bool use_flash;
@@ -1835,6 +1848,19 @@ struct train_params get_default_train_params() {
params.n_rank_norm = 1;
params.n_rank_output = 4;
params.custom_n_rank_attention_norm = false;
params.custom_n_rank_wq = false;
params.custom_n_rank_wk = false;
params.custom_n_rank_wv = false;
params.custom_n_rank_wo = false;
params.custom_n_rank_ffn_norm = false;
params.custom_n_rank_w1 = false;
params.custom_n_rank_w2 = false;
params.custom_n_rank_w3 = false;
params.custom_n_rank_tok_embeddings = false;
params.custom_n_rank_norm = false;
params.custom_n_rank_output = false;
params.samples_start_after_nl = false;
params.use_adam = true;
params.use_flash = true;
@@ -1887,19 +1913,19 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor (default %d)\n", params->n_rank_ffn_norm);
fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor (default %d)\n", params->n_rank_norm);
fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor (default %d)\n", params->n_rank_tok_embeddings);
fprintf(stderr, " --rank-out N LORA rank for output tensor (default %d)\n", params->n_rank_output);
fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq);
fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk);
fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv);
fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo);
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1);
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
@@ -2063,72 +2089,84 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->n_rank_attention_norm = std::stoi(argv[i]);
params->custom_n_rank_attention_norm = true;
} else if (arg == "--rank-ffn-norm") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_ffn_norm = std::stoi(argv[i]);
params->custom_n_rank_ffn_norm = true;
} else if (arg == "--rank-out-norm") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_norm = std::stoi(argv[i]);
params->custom_n_rank_norm = true;
} else if (arg == "--rank-tok-embd") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_tok_embeddings = std::stoi(argv[i]);
params->custom_n_rank_tok_embeddings = true;
} else if (arg == "--rank-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_output = std::stoi(argv[i]);
params->custom_n_rank_output = true;
} else if (arg == "--rank-wq") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_wq = std::stoi(argv[i]);
params->custom_n_rank_wq = true;
} else if (arg == "--rank-wk") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_wk = std::stoi(argv[i]);
params->custom_n_rank_wk = true;
} else if (arg == "--rank-wv") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_wv = std::stoi(argv[i]);
params->custom_n_rank_wv = true;
} else if (arg == "--rank-wo") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_wo = std::stoi(argv[i]);
params->custom_n_rank_wo = true;
} else if (arg == "--rank-w1") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_w1 = std::stoi(argv[i]);
params->custom_n_rank_w1 = true;
} else if (arg == "--rank-w2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_w2 = std::stoi(argv[i]);
params->custom_n_rank_w2 = true;
} else if (arg == "--rank-w3") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_w3 = std::stoi(argv[i]);
params->custom_n_rank_w3 = true;
} else if (arg == "--samples-after-nl") {
params->samples_start_after_nl = true;
} else if (arg == "--use-lbfgs") {
@@ -2494,18 +2532,30 @@ int main(int argc, char ** argv) {
}
lora.hparams.lora_r = params.lora_r;
lora.hparams.lora_alpha = params.lora_alpha;
lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
lora.hparams.n_rank_wq = params.n_rank_wq;
lora.hparams.n_rank_wk = params.n_rank_wk;
lora.hparams.n_rank_wv = params.n_rank_wv;
lora.hparams.n_rank_wo = params.n_rank_wo;
lora.hparams.n_rank_ffn_norm = params.n_rank_ffn_norm;
lora.hparams.n_rank_w1 = params.n_rank_w1;
lora.hparams.n_rank_w2 = params.n_rank_w2;
lora.hparams.n_rank_w3 = params.n_rank_w3;
lora.hparams.n_rank_tok_embeddings = params.n_rank_tok_embeddings;
lora.hparams.n_rank_norm = params.n_rank_norm;
lora.hparams.n_rank_output = params.n_rank_output;
int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
int n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
int n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
int n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
int n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
int n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
int n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
int n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
int n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
int n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
lora.hparams.n_rank_wq = n_rank_wq;
lora.hparams.n_rank_wk = n_rank_wk;
lora.hparams.n_rank_wv = n_rank_wv;
lora.hparams.n_rank_wo = n_rank_wo;
lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
lora.hparams.n_rank_w1 = n_rank_w1;
lora.hparams.n_rank_w2 = n_rank_w2;
lora.hparams.n_rank_w3 = n_rank_w3;
lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
lora.hparams.n_rank_norm = n_rank_norm;
lora.hparams.n_rank_output = n_rank_output;
// set opt params from command line
if (params.use_adam) {
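
Note how the defaults are resolved in the hunk above: with '--lora-r 8' and no per-tensor overrides, the matrix tensors (wq, wk, wv, wo, w1, w2, w3, tok_embeddings, output) all resolve to rank 8, while the norm tensors (attention_norm, ffn_norm, norm) fall back to rank 1, matching the usage text's advice that norm tensors should generally have rank 1.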
@@ -2550,30 +2600,30 @@ int main(int argc, char ** argv) {
}
const bool opt_param_count_changed = (
(lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)
|| (lora.hparams.n_rank_wq != params.n_rank_wq)
|| (lora.hparams.n_rank_wk != params.n_rank_wk)
|| (lora.hparams.n_rank_wv != params.n_rank_wv)
|| (lora.hparams.n_rank_wo != params.n_rank_wo)
|| (lora.hparams.n_rank_ffn_norm != params.n_rank_ffn_norm)
|| (lora.hparams.n_rank_w1 != params.n_rank_w1)
|| (lora.hparams.n_rank_w2 != params.n_rank_w2)
|| (lora.hparams.n_rank_w3 != params.n_rank_w3)
|| (lora.hparams.n_rank_tok_embeddings != params.n_rank_tok_embeddings)
|| (lora.hparams.n_rank_norm != params.n_rank_norm)
|| (lora.hparams.n_rank_output != params.n_rank_output)
(lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
|| (lora.hparams.n_rank_wq != n_rank_wq)
|| (lora.hparams.n_rank_wk != n_rank_wk)
|| (lora.hparams.n_rank_wv != n_rank_wv)
|| (lora.hparams.n_rank_wo != n_rank_wo)
|| (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
|| (lora.hparams.n_rank_w1 != n_rank_w1)
|| (lora.hparams.n_rank_w2 != n_rank_w2)
|| (lora.hparams.n_rank_w3 != n_rank_w3)
|| (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
|| (lora.hparams.n_rank_norm != n_rank_norm)
|| (lora.hparams.n_rank_output != n_rank_output)
);
const bool opt_past_changed = opt->params.past != params.opt_past;
GGML_ASSERT(opt_param_count_changed == false);
GGML_ASSERT(opt_past_changed == false);
if (opt_param_count_changed) {
print_lora_params(&lora.hparams);
GGML_ASSERT(!"Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
// need to discard previous optimizer gradient statistics and opt_init with new shapes
// TODO
}
if (opt_past_changed) {
GGML_ASSERT(!"Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
// need to discard previous optimizer past function value statistics and opt_init with new shapes
// TODO
}
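
As the assert messages state, ranks recorded in an existing checkpoint cannot be changed in place. A sketch of the suggested workaround (the binary name, '--checkpoint-out', and the paths are assumed; passing an empty '--checkpoint-in' is the mechanism named in the assert):

    # discard previous optimizer state by starting from an empty checkpoint
    ./bin/finetune --model-base base-model.gguf --checkpoint-in '' \
        --checkpoint-out chk-lora.gguf --lora-r 16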