specify default lora rank with '--lora-r N'

'--lora-r N' specifies the default rank for all tensors. '--rank-wq N', etc. override this default rank for specific tensor types.
parent 8c2d7e37f9
commit c08fcf5947
1 changed file with 90 additions and 40 deletions
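As a rough illustration of the rank-resolution rule described in the commit message, here is a minimal standalone sketch. The names rank_choice and resolve_rank are hypothetical and do not appear in the patch; the change itself tracks a custom_n_rank_* flag per tensor and resolves the ranks inline in main(), as shown in the last hunks below.

    #include <cstdint>

    // Hypothetical illustration of the behaviour added by this commit:
    // a tensor-specific --rank-* flag wins; otherwise norm tensors fall
    // back to rank 1 and all other tensors fall back to --lora-r.
    struct rank_choice {
        uint32_t value;   // value passed with the tensor-specific --rank-* flag
        bool     custom;  // true if that flag was given on the command line
    };

    static uint32_t resolve_rank(const rank_choice & c, uint32_t lora_r, bool is_norm_tensor) {
        if (c.custom)       return c.value;  // explicit --rank-* override
        if (is_norm_tensor) return 1;        // norm tensors default to rank 1
        return lora_r;                       // everything else uses --lora-r
    }

For example, '--lora-r 8 --rank-wq 16' gives wq rank 16, leaves wk/wv/wo at the default 8, and keeps the norm tensors at rank 1 unless their own --rank-*-norm flags are passed.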
@@ -1760,6 +1760,19 @@ struct train_params {
     uint32_t n_rank_norm;
     uint32_t n_rank_output;
 
+    bool custom_n_rank_attention_norm;
+    bool custom_n_rank_wq;
+    bool custom_n_rank_wk;
+    bool custom_n_rank_wv;
+    bool custom_n_rank_wo;
+    bool custom_n_rank_ffn_norm;
+    bool custom_n_rank_w1;
+    bool custom_n_rank_w2;
+    bool custom_n_rank_w3;
+    bool custom_n_rank_tok_embeddings;
+    bool custom_n_rank_norm;
+    bool custom_n_rank_output;
+
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
@@ -1835,6 +1848,19 @@ struct train_params get_default_train_params() {
     params.n_rank_norm = 1;
     params.n_rank_output = 4;
 
+    params.custom_n_rank_attention_norm = false;
+    params.custom_n_rank_wq = false;
+    params.custom_n_rank_wk = false;
+    params.custom_n_rank_wv = false;
+    params.custom_n_rank_wo = false;
+    params.custom_n_rank_ffn_norm = false;
+    params.custom_n_rank_w1 = false;
+    params.custom_n_rank_w2 = false;
+    params.custom_n_rank_w3 = false;
+    params.custom_n_rank_tok_embeddings = false;
+    params.custom_n_rank_norm = false;
+    params.custom_n_rank_output = false;
+
     params.samples_start_after_nl = false;
     params.use_adam = true;
     params.use_flash = true;
@@ -1887,19 +1913,19 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
-    fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
+    fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
-    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
+    fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
-    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor (default %d)\n", params->n_rank_ffn_norm);
+    fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
-    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor (default %d)\n", params->n_rank_norm);
+    fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
-    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor (default %d)\n", params->n_rank_tok_embeddings);
+    fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-out N LORA rank for output tensor (default %d)\n", params->n_rank_output);
+    fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-wq N LORA rank for wq tensor (default %d)\n", params->n_rank_wq);
+    fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-wk N LORA rank for wk tensor (default %d)\n", params->n_rank_wk);
+    fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-wv N LORA rank for wv tensor (default %d)\n", params->n_rank_wv);
+    fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-wo N LORA rank for wo tensor (default %d)\n", params->n_rank_wo);
+    fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor (default %d)\n", params->n_rank_w1);
+    fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
+    fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
-    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
+    fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
     fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
     fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
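The '--lora-alpha' and '--lora-r' help lines above refer to the usual LoRA convention of scaling the adapter update by alpha/r; the scaling computation itself is outside this diff, so the following one-liner is only a sketch of that relation:

    // LoRA update scaling referred to by the --lora-alpha / --lora-r help text:
    // e.g. alpha = 16 and r = 4 give a scaling factor of 4.0f.
    static float lora_scaling(int lora_alpha, int lora_r) {
        return (float) lora_alpha / (float) lora_r;
    }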
@@ -2063,72 +2089,84 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_rank_attention_norm = std::stoi(argv[i]);
+            params->custom_n_rank_attention_norm = true;
         } else if (arg == "--rank-ffn-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_ffn_norm = std::stoi(argv[i]);
+            params->custom_n_rank_ffn_norm = true;
         } else if (arg == "--rank-out-norm") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_norm = std::stoi(argv[i]);
+            params->custom_n_rank_norm = true;
         } else if (arg == "--rank-tok-embd") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_tok_embeddings = std::stoi(argv[i]);
+            params->custom_n_rank_tok_embeddings = true;
         } else if (arg == "--rank-out") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_output = std::stoi(argv[i]);
+            params->custom_n_rank_output = true;
         } else if (arg == "--rank-wq") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wq = std::stoi(argv[i]);
+            params->custom_n_rank_wq = true;
         } else if (arg == "--rank-wk") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wk = std::stoi(argv[i]);
+            params->custom_n_rank_wk = true;
         } else if (arg == "--rank-wv") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wv = std::stoi(argv[i]);
+            params->custom_n_rank_wv = true;
         } else if (arg == "--rank-wo") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_wo = std::stoi(argv[i]);
+            params->custom_n_rank_wo = true;
         } else if (arg == "--rank-w1") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w1 = std::stoi(argv[i]);
+            params->custom_n_rank_w1 = true;
         } else if (arg == "--rank-w2") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w2 = std::stoi(argv[i]);
+            params->custom_n_rank_w2 = true;
         } else if (arg == "--rank-w3") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->n_rank_w3 = std::stoi(argv[i]);
+            params->custom_n_rank_w3 = true;
         } else if (arg == "--samples-after-nl") {
             params->samples_start_after_nl = true;
         } else if (arg == "--use-lbfgs") {
@@ -2494,18 +2532,30 @@ int main(int argc, char ** argv) {
     }
     lora.hparams.lora_r = params.lora_r;
    lora.hparams.lora_alpha = params.lora_alpha;
-    lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
+    int n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
-    lora.hparams.n_rank_wq = params.n_rank_wq;
+    int n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
-    lora.hparams.n_rank_wk = params.n_rank_wk;
+    int n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
-    lora.hparams.n_rank_wv = params.n_rank_wv;
+    int n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
-    lora.hparams.n_rank_wo = params.n_rank_wo;
+    int n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
-    lora.hparams.n_rank_ffn_norm = params.n_rank_ffn_norm;
+    int n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
-    lora.hparams.n_rank_w1 = params.n_rank_w1;
+    int n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
-    lora.hparams.n_rank_w2 = params.n_rank_w2;
+    int n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
-    lora.hparams.n_rank_w3 = params.n_rank_w3;
+    int n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
-    lora.hparams.n_rank_tok_embeddings = params.n_rank_tok_embeddings;
+    int n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
-    lora.hparams.n_rank_norm = params.n_rank_norm;
+    int n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
-    lora.hparams.n_rank_output = params.n_rank_output;
+    int n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
+    lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
+    lora.hparams.n_rank_wq = n_rank_wq;
+    lora.hparams.n_rank_wk = n_rank_wk;
+    lora.hparams.n_rank_wv = n_rank_wv;
+    lora.hparams.n_rank_wo = n_rank_wo;
+    lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
+    lora.hparams.n_rank_w1 = n_rank_w1;
+    lora.hparams.n_rank_w2 = n_rank_w2;
+    lora.hparams.n_rank_w3 = n_rank_w3;
+    lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
+    lora.hparams.n_rank_norm = n_rank_norm;
+    lora.hparams.n_rank_output = n_rank_output;
 
     // set opt params from command line
     if (params.use_adam) {
@@ -2550,30 +2600,30 @@ int main(int argc, char ** argv) {
     }
 
     const bool opt_param_count_changed = (
-        (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)
+        (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
-        || (lora.hparams.n_rank_wq != params.n_rank_wq)
+        || (lora.hparams.n_rank_wq != n_rank_wq)
-        || (lora.hparams.n_rank_wk != params.n_rank_wk)
+        || (lora.hparams.n_rank_wk != n_rank_wk)
-        || (lora.hparams.n_rank_wv != params.n_rank_wv)
+        || (lora.hparams.n_rank_wv != n_rank_wv)
-        || (lora.hparams.n_rank_wo != params.n_rank_wo)
+        || (lora.hparams.n_rank_wo != n_rank_wo)
-        || (lora.hparams.n_rank_ffn_norm != params.n_rank_ffn_norm)
+        || (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
-        || (lora.hparams.n_rank_w1 != params.n_rank_w1)
+        || (lora.hparams.n_rank_w1 != n_rank_w1)
-        || (lora.hparams.n_rank_w2 != params.n_rank_w2)
+        || (lora.hparams.n_rank_w2 != n_rank_w2)
-        || (lora.hparams.n_rank_w3 != params.n_rank_w3)
+        || (lora.hparams.n_rank_w3 != n_rank_w3)
-        || (lora.hparams.n_rank_tok_embeddings != params.n_rank_tok_embeddings)
+        || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
-        || (lora.hparams.n_rank_norm != params.n_rank_norm)
+        || (lora.hparams.n_rank_norm != n_rank_norm)
-        || (lora.hparams.n_rank_output != params.n_rank_output)
+        || (lora.hparams.n_rank_output != n_rank_output)
     );
 
     const bool opt_past_changed = opt->params.past != params.opt_past;
 
-    GGML_ASSERT(opt_param_count_changed == false);
-    GGML_ASSERT(opt_past_changed == false);
-
     if (opt_param_count_changed) {
+        print_lora_params(&lora.hparams);
+        GGML_ASSERT(!"Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
         // need to discard previous optimizer gradient statistics and opt_init with new shapes
         // TODO
     }
     if (opt_past_changed) {
+        GGML_ASSERT(!"Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
         // need to discard previous optimizer past function value statistics and opt_init with new shapes
         // TODO
     }