diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index f476813ab..260b89ce3 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -2052,6 +2052,7 @@ struct train_params {
 
     int32_t lora_r;
     int32_t lora_alpha;
+    bool custom_lora_alpha;
 
     uint32_t n_rank_attention_norm;
     uint32_t n_rank_wq;
@@ -2147,8 +2148,9 @@ struct train_params get_default_train_params() {
     params.custom_rope_freq_base = false;
     params.custom_rope_freq_scale = false;
 
-    params.lora_alpha = 4;
     params.lora_r = 4;
+    params.lora_alpha = 4;
+    params.custom_lora_alpha = false;
 
     params.n_rank_attention_norm = 1;
     params.n_rank_wq = 4;
@@ -2406,6 +2408,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->lora_alpha = std::stoi(argv[i]);
+            params->custom_lora_alpha = true;
         } else if (arg == "--lora-r") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2909,7 +2912,7 @@ int main(int argc, char ** argv) {
         lora.hparams.rope_freq_scale = params.rope_freq_scale;
     }
     lora.hparams.lora_r = params.lora_r;
-    lora.hparams.lora_alpha = params.lora_alpha;
+    lora.hparams.lora_alpha = params.custom_lora_alpha ? params.lora_alpha : params.lora_r;
     int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
     int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
     int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
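For reviewers, a minimal standalone sketch of the defaulting rule introduced in the last hunk, assuming only what the patch shows: when the alpha flag is never parsed, lora_alpha now follows lora_r instead of staying at the fixed default. The struct and helper names below are hypothetical and exist only for this illustration, not in the patch.

    // Illustration only: train_params_sketch and effective_lora_alpha are
    // hypothetical names mirroring the patched logic.
    #include <cstdint>
    #include <cstdio>

    struct train_params_sketch {
        int32_t lora_r            = 4;
        int32_t lora_alpha        = 4;
        bool    custom_lora_alpha = false; // set only when the alpha option is parsed
    };

    // Mirrors: params.custom_lora_alpha ? params.lora_alpha : params.lora_r
    static int32_t effective_lora_alpha(const train_params_sketch & p) {
        return p.custom_lora_alpha ? p.lora_alpha : p.lora_r;
    }

    int main() {
        train_params_sketch p;
        p.lora_r = 16;                                // only the rank was set
        std::printf("%d\n", effective_lora_alpha(p)); // 16: alpha tracks r by default

        p.lora_alpha        = 8;                      // alpha set explicitly
        p.custom_lora_alpha = true;
        std::printf("%d\n", effective_lora_alpha(p)); // 8: explicit value wins
        return 0;
    }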