diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index f4beb59e2..77d0cacda 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -2270,6 +2270,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq); fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk); fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv); + fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo); fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1); fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2); fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3); @@ -2465,6 +2466,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_rank_wv = std::stoi(argv[i]); + } else if (arg == "--rank-wo") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_rank_wo = std::stoi(argv[i]); } else if (arg == "--rank-w1") { if (++i >= argc) { invalid_param = true;