add command line option --rank-wo N for rank of wo tensor

This commit is contained in:
xaedes 2023-08-23 20:00:48 +02:00
parent 77a3092c83
commit 1a5f0a30e0
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -2270,6 +2270,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq); fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq);
fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk); fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk);
fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv); fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv);
fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo);
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1); fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1);
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2); fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3); fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
@ -2465,6 +2466,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break; break;
} }
params->n_rank_wv = std::stoi(argv[i]); params->n_rank_wv = std::stoi(argv[i]);
} else if (arg == "--rank-wo") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_rank_wo = std::stoi(argv[i]);
} else if (arg == "--rank-w1") { } else if (arg == "--rank-w1") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;