From 1a5f0a30e0eac106a08ab1c40d5b603dfecbd3f0 Mon Sep 17 00:00:00 2001 From: xaedes Date: Wed, 23 Aug 2023 20:00:48 +0200 Subject: [PATCH] add command line option `--rank-wo N` for rank of wo tensor --- examples/finetune/finetune.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index f4beb59e2..77d0cacda 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -2270,6 +2270,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq); fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk); fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv); + fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo); fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1); fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2); fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3); @@ -2465,6 +2466,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_rank_wv = std::stoi(argv[i]); + } else if (arg == "--rank-wo") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_rank_wo = std::stoi(argv[i]); } else if (arg == "--rank-w1") { if (++i >= argc) { invalid_param = true;