llama : do not use KV cache for non-causal models
ggml-ci
This commit is contained in:
parent
d0347840c1
commit
eb42596277
3 changed files with 109 additions and 39 deletions
|
@ -13,7 +13,7 @@ async def main():
|
|||
model_url = "http://127.0.0.1:6900"
|
||||
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
||||
url= f"{model_url}/embedding",
|
||||
json= {"content": str(i)*32}
|
||||
json= {"content": str(0)*32}
|
||||
) for i in range(n)])
|
||||
|
||||
for response in responses:
|
||||
|
|
|
@ -2044,6 +2044,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
|
||||
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" --pooling {none,mean,cls}\n");
|
||||
printf(" pooling type for embeddings, use model default if unspecified\n");
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
|
@ -2284,6 +2286,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--pooling")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
|
||||
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
|
||||
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
|
||||
else { invalid_param = true; break; }
|
||||
}
|
||||
else if (arg == "--threads" || arg == "-t")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue