llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
This commit is contained in:
parent
c43c2da8af
commit
898aeca90a
15 changed files with 763 additions and 257 deletions
|
@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf("options:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
||||
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
printf(" --rope-scaling {none,linear,yarn}\n");
|
||||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
|
||||
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
|
||||
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
|
||||
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
|
||||
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
if (llama_mlock_supported())
|
||||
|
@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.n_ctx = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-scaling")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
||||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
||||
else { invalid_param = true; break; }
|
||||
}
|
||||
else if (arg == "--rope-freq-base")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.rope_freq_scale = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-ext-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_ext_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-attn-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_attn_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-beta-fast")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_fast = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-beta-slow")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--memory-f32" || arg == "--memory_f32")
|
||||
{
|
||||
params.memory_f16 = false;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue