server : disable cached prompts with self-extend

Georgi Gerganov 2024-03-06 18:51:40 +02:00
parent 61b63705dc
commit aef02b11ec
2 changed files with 19 additions and 31 deletions


@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])

     for response in responses:


@@ -816,6 +816,11 @@ struct llama_server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        if (slot.params.cache_prompt && slot.ga_n != 1) {
+            LOG_WARNING("cache_prompt is not supported with group-attention", {});
+            slot.params.cache_prompt = false;
+        }
+
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
             LOG_WARNING("Max tokens to predict exceeds server configuration", {
@@ -1769,6 +1774,8 @@ struct llama_server_context {
                         slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
                     } else {
+                        GGML_ASSERT(slot.ga_n == 1);
+
                         // push the prompt into the sampling context (do not apply grammar)
                         for (auto & token : prompt_tokens) {
                             llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
                         }
@@ -1783,34 +1790,17 @@ struct llama_server_context {
                         }

                         slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
-
-                        if (slot.ga_n != 1) {
-                            int ga_i = 0;
-                            int32_t ga_n = slot.ga_n;
-                            int32_t ga_w = slot.ga_w;
-                            int32_t slot_npast = 0;
-                            for (int k = 0; k < slot.n_past; ++k) {
-                                while (slot_npast >= ga_i + ga_w) {
-                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                    slot_npast -= bd;
-                                    ga_i += ga_w/ga_n;
-                                }
-                                slot_npast++;
-                            }
-                            slot.n_past_se = slot_npast;
-                            slot.ga_i = ga_i;
-                        }
-
-                        LOG_INFO("slot progression", {
-                            { "id_slot", slot.id },
-                            { "id_task", slot.id_task },
-                            { "n_past", slot.n_past },
-                            { "n_past_se", slot.n_past_se },
-                            { "ga_i", slot.ga_i },
-                            { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
-                        });
                     }
+
+                    LOG_INFO("slot progression", {
+                        { "id_slot", slot.id },
+                        { "id_task", slot.id_task },
+                        { "n_past", slot.n_past },
+                        { "n_past_se", slot.n_past_se },
+                        { "ga_i", slot.ga_i },
+                        { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
+                    });

                     slot.cache_tokens = prompt_tokens;

                     if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
@@ -1841,15 +1831,13 @@ struct llama_server_context {
                         {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                     });

-                    std::vector<llama_token> prefix_tokens = prompt_tokens;
-
                     int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

                     int32_t ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;

-                    for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) {
+                    for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
                         if (slot.ga_n != 1) {
                             while (slot_npast >= ga_i + ga_w) {
                                 const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1858,7 +1846,7 @@ struct llama_server_context {
                             }
                         }

-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                        llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
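
Note on the hunk at -1783: the removed loop re-derived the self-extend (group-attention) bookkeeping for a reused prompt prefix, computing slot.n_past_se and slot.ga_i from slot.n_past by replaying one shift per filled window of width ga_w. With the guard added at -816, cache_prompt is forced off whenever slot.ga_n != 1, so this path can simply assert slot.ga_n == 1 (the new GGML_ASSERT at +1774). A minimal standalone sketch of the remapping itself is kept below for reference; the function name and the example values are illustrative only, not llama.cpp API.

    #include <cassert>
    #include <cstdio>

    // Illustration only: mirrors the loop deleted by this commit.
    // Replays n_past single-token steps; every time the running position reaches
    // the end of the current window (ga_i + ga_w), it is pulled back by
    // (ga_w/ga_n)*(ga_n - 1) and the window start advances by ga_w/ga_n.
    static void remap_self_extend(int n_past, int ga_n, int ga_w, int & n_past_se, int & ga_i) {
        assert(ga_n > 1 && ga_w > 0);

        n_past_se = 0;
        ga_i      = 0;

        for (int k = 0; k < n_past; ++k) {
            while (n_past_se >= ga_i + ga_w) {
                const int bd = (ga_w/ga_n)*(ga_n - 1);
                n_past_se -= bd;
                ga_i      += ga_w/ga_n;
            }
            n_past_se++;
        }
    }

    int main() {
        int n_past_se = 0;
        int ga_i      = 0;

        // hypothetical values: 1000 cached tokens, ga_n = 4, ga_w = 512
        remap_self_extend(1000, 4, 512, n_past_se, ga_i);
        printf("n_past_se = %d, ga_i = %d\n", n_past_se, ga_i);

        return 0;
    }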