Add --rope-scale parameter (#2544)

* common.cpp : Add --rope-scale parameter
* README.md : Add info about using linear rope scaling
This commit is contained in:
klosax 2023-08-07 19:07:19 +02:00 committed by GitHub
parent 93356bdb7a
commit f3c3b4b167
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 2 deletions

View file

@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");