set lora_alpha to value of lora_r if it is not set via command line

otherwise only changing lora_r will change scaling of lora adapter used in prediction
This commit is contained in:
xaedes 2023-09-14 17:58:31 +02:00
parent 20cf1a4589
commit 3a9c1d7f5a
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -2052,6 +2052,7 @@ struct train_params {
int32_t lora_r;
int32_t lora_alpha;
bool custom_lora_alpha;
uint32_t n_rank_attention_norm;
uint32_t n_rank_wq;
@ -2147,8 +2148,9 @@ struct train_params get_default_train_params() {
params.custom_rope_freq_base = false;
params.custom_rope_freq_scale = false;
params.lora_alpha = 4;
params.lora_r = 4;
params.lora_alpha = 4;
params.custom_lora_alpha = false;
params.n_rank_attention_norm = 1;
params.n_rank_wq = 4;
@ -2406,6 +2408,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->lora_alpha = std::stoi(argv[i]);
params->custom_lora_alpha = true;
} else if (arg == "--lora-r") {
if (++i >= argc) {
invalid_param = true;
@ -2909,7 +2912,7 @@ int main(int argc, char ** argv) {
lora.hparams.rope_freq_scale = params.rope_freq_scale;
}
lora.hparams.lora_r = params.lora_r;
lora.hparams.lora_alpha = params.lora_alpha;
lora.hparams.lora_alpha = params.custom_lora_alpha ? params.lora_alpha : params.lora_r;
int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;