server : disable cached prompts with self-extend

author Georgi Gerganov 2024-03-06 18:51:40 +02:00
parent 61b63705dc
commit aef02b11ec
2 changed files with 19 additions and 31 deletions


@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])

     for response in responses:


@@ -816,6 +816,11 @@ struct llama_server_context {
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        if (slot.params.cache_prompt && slot.ga_n != 1) {
+            LOG_WARNING("cache_prompt is not supported with group-attention", {});
+            slot.params.cache_prompt = false;
+        }
+
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
             LOG_WARNING("Max tokens to predict exceeds server configuration", {
@@ -1769,6 +1774,8 @@ struct llama_server_context {
                slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
            } else {
+                GGML_ASSERT(slot.ga_n == 1);
+
                // push the prompt into the sampling context (do not apply grammar)
                for (auto & token : prompt_tokens) {
                    llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
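This else branch is the cache_prompt path, which the new assert documents as unreachable for self-extend slots. The replay loop below it exists because the sampler keeps its own history of recent tokens for repetition penalties, so a prompt restored from the cache still has to be pushed into that history (with apply_grammar = false) even though it is not re-decoded. A toy sketch of the idea, with a made-up sampler_history type standing in for llama_sampling_context:

#include <cstddef>
#include <cstdint>
#include <deque>
#include <vector>

using llama_token = int32_t;  // stand-in for the real typedef

// stand-in for the recent-token ring kept by llama_sampling_context
struct sampler_history {
    size_t capacity = 64;
    std::deque<llama_token> prev;

    // record a token without grammar checks, mirroring
    // llama_sampling_accept(ctx_sampling, ctx, token, /*apply_grammar=*/false)
    void accept(llama_token tok) {
        prev.push_back(tok);
        if (prev.size() > capacity) {
            prev.pop_front();
        }
    }
};

// replay a cached prompt so repetition penalties "see" it on the next sample
static void replay_prompt(sampler_history & hist, const std::vector<llama_token> & prompt_tokens) {
    for (const llama_token tok : prompt_tokens) {
        hist.accept(tok);
    }
}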
@@ -1783,22 +1790,6 @@ struct llama_server_context {
                }

                slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;

-                if (slot.ga_n != 1) {
-                    int ga_i = 0;
-                    int32_t ga_n = slot.ga_n;
-                    int32_t ga_w = slot.ga_w;
-                    int32_t slot_npast = 0;
-                    for (int k = 0; k < slot.n_past; ++k) {
-                        while (slot_npast >= ga_i + ga_w) {
-                            const int bd = (ga_w/ga_n)*(ga_n - 1);
-                            slot_npast -= bd;
-                            ga_i += ga_w/ga_n;
-                        }
-                        slot_npast++;
-                    }
-                    slot.n_past_se = slot_npast;
-                    slot.ga_i = ga_i;
-                }

                LOG_INFO("slot progression", {
@@ -1809,7 +1800,6 @@ struct llama_server_context {
                    { "ga_i", slot.ga_i },
                    { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
                });
-            }

            slot.cache_tokens = prompt_tokens;
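For reference, the block deleted above recomputed the self-extend ("group attention") position for a prompt prefix restored from the cache: once ga_w positions fill up, the position is pulled back by (ga_w/ga_n)*(ga_n - 1), so over a full window the effective advance is only ga_w/ga_n. A hedged sketch of that mapping as a free function (the name ga_compressed_pos is made up here):

#include <cstdint>

// Sketch of the removed loop: map a linear token count n_past to the
// compressed KV position used by self-extend / group-attention.
// ga_n is the group factor, ga_w the window width.
static int32_t ga_compressed_pos(int32_t n_past, int32_t ga_n, int32_t ga_w) {
    int32_t ga_i = 0;  // start of the current group window
    int32_t pos  = 0;  // compressed position

    for (int32_t k = 0; k < n_past; ++k) {
        while (pos >= ga_i + ga_w) {
            const int32_t bd = (ga_w/ga_n)*(ga_n - 1);  // amount the window is pulled back by
            pos  -= bd;
            ga_i += ga_w/ga_n;
        }
        pos++;
    }

    return pos;  // what the old code stored in slot.n_past_se
}

With cache_prompt now forced off whenever ga_n != 1, a self-extend slot always starts its prompt from n_past == 0, so this recomputation could never run anymore and is removed.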
@@ -1841,15 +1831,13 @@ struct llama_server_context {
                {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
            });

-            std::vector<llama_token> prefix_tokens = prompt_tokens;
-
            int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

            int32_t ga_i = slot.ga_i;
            int32_t ga_n = slot.ga_n;
            int32_t ga_w = slot.ga_w;

-            for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) {
+            for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
                if (slot.ga_n != 1) {
                    while (slot_npast >= ga_i + ga_w) {
                        const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1858,7 +1846,7 @@ struct llama_server_context {
                    }
                }

-                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);

                slot_npast++;
            }
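After the rename the ingestion loop reads straight from prompt_tokens; the separate prefix_tokens copy removed above no longer exists. A condensed, self-contained sketch of the loop's position bookkeeping (system_tokens offset omitted), with llama_batch_add replaced by a hypothetical push_token stub so it can be compiled and traced on its own:

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;  // stand-in for the real typedef

// stub for llama_batch_add(batch, token, pos, { slot.id }, false)
static void push_token(llama_token tok, int32_t pos) {
    printf("token %d -> position %d\n", (int) tok, (int) pos);
}

int main() {
    const std::vector<llama_token> prompt_tokens = { 11, 22, 33, 44, 55, 66, 77, 88 };

    // slot state, roughly as in server.cpp
    int32_t n_past    = 0;   // tokens already in the cache (0 here: no prompt cache with self-extend)
    int32_t n_past_se = 0;   // compressed position when self-extend is active
    int32_t ga_i      = 0;   // start of the current group window
    const int32_t ga_n = 2;  // group factor (1 would disable self-extend)
    const int32_t ga_w = 4;  // group window width

    int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;

    for (; n_past < (int32_t) prompt_tokens.size(); ++n_past) {
        if (ga_n != 1) {
            // same remap as in the server loop: pull the position back once a window fills up
            while (slot_npast >= ga_i + ga_w) {
                const int32_t bd = (ga_w/ga_n)*(ga_n - 1);
                slot_npast -= bd;
                ga_i       += ga_w/ga_n;
            }
        }

        push_token(prompt_tokens[n_past], slot_npast);
        slot_npast++;
    }

    return 0;
}

With ga_n = 2 and ga_w = 4 the printed positions repeat within each group (0 1 2 3 2 3 4 5), which is the position compression self-extend relies on; with ga_n = 1 the remap is skipped and the positions are simply 0..7.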