server : disable cached prompts with self-extend
commit aef02b11ec
parent 61b63705dc

2 changed files with 19 additions and 31 deletions
@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])
 
     for response in responses:
@@ -816,6 +816,11 @@ struct llama_server_context {
        slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
        slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
+       if (slot.params.cache_prompt && slot.ga_n != 1) {
+           LOG_WARNING("cache_prompt is not supported with group-attention", {});
+           slot.params.cache_prompt = false;
+       }
+
        if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
            // Might be better to reject the request with a 400 ?
            LOG_WARNING("Max tokens to predict exceeds server configuration", {
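
Note on the hunk above: with self-extend (group attention, ga_n != 1) the server remaps KV-cache positions, so a prompt prefix cached from an earlier request no longer lines up with the positions a new request would use; the new guard therefore drops a client's cache_prompt flag and logs a warning. Below is a minimal standalone sketch of that behaviour; the slot_sketch / slot_params_sketch types and the apply_cache_prompt_guard helper are hypothetical stand-ins for illustration, not the server's own code:

    // Hypothetical sketch of the guard added in the hunk above: a per-request
    // "cache_prompt" flag is silently dropped when the server runs with
    // group-attention / self-extend (ga_n != 1).
    #include <cstdio>

    struct slot_params_sketch {
        bool cache_prompt = false; // requested by the client
    };

    struct slot_sketch {
        int ga_n = 1;              // group-attention factor (1 = self-extend disabled)
        slot_params_sketch params;
    };

    static void apply_cache_prompt_guard(slot_sketch & slot) {
        if (slot.params.cache_prompt && slot.ga_n != 1) {
            // mirrors the LOG_WARNING + override in the diff
            std::printf("warn: cache_prompt is not supported with group-attention\n");
            slot.params.cache_prompt = false;
        }
    }

    int main() {
        slot_sketch slot;
        slot.ga_n = 4;                    // e.g. a server started with group-attention factor 4
        slot.params.cache_prompt = true;  // client asked for prompt caching

        apply_cache_prompt_guard(slot);
        std::printf("effective cache_prompt = %d\n", slot.params.cache_prompt ? 1 : 0); // prints 0
    }

With a group-attention factor of 4 and cache_prompt requested, the effective setting comes out false, matching the warning path in the diff.
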
@@ -1769,6 +1774,8 @@ struct llama_server_context {
 
                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
                 } else {
+                    GGML_ASSERT(slot.ga_n == 1);
+
                     // push the prompt into the sampling context (do not apply grammar)
                     for (auto & token : prompt_tokens) {
                         llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
@@ -1783,34 +1790,17 @@ struct llama_server_context {
                     }
 
                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
-
-                    if (slot.ga_n != 1) {
-                        int ga_i = 0;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-                        int32_t slot_npast = 0;
-                        for (int k = 0; k < slot.n_past; ++k) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                            slot_npast++;
-                        }
-                        slot.n_past_se = slot_npast;
-                        slot.ga_i = ga_i;
-                    }
-
-                    LOG_INFO("slot progression", {
-                        { "id_slot", slot.id },
-                        { "id_task", slot.id_task },
-                        { "n_past", slot.n_past },
-                        { "n_past_se", slot.n_past_se },
-                        { "ga_i", slot.ga_i },
-                        { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
-                    });
                 }
 
+                LOG_INFO("slot progression", {
+                    { "id_slot", slot.id },
+                    { "id_task", slot.id_task },
+                    { "n_past", slot.n_past },
+                    { "n_past_se", slot.n_past_se },
+                    { "ga_i", slot.ga_i },
+                    { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
+                });
+
                 slot.cache_tokens = prompt_tokens;
 
                 if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
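
The block deleted above derived the self-extend state (slot.n_past_se, slot.ga_i) for a reused prefix of slot.n_past cached tokens: positions are walked one by one, and every time the compressed counter crosses a group boundary it is pulled back by bd = (ga_w/ga_n)*(ga_n - 1). The following is a standalone sketch of the same arithmetic; compress_positions is a hypothetical free function, not part of the server:

    // Standalone sketch of the position compression the removed block performed
    // (same arithmetic; ga_n/ga_w correspond to the group-attention factor and width).
    #include <cstdio>
    #include <cstdint>

    // Map a "real" token count n_past to the compressed self-extend position
    // n_past_se and the group index ga_i, exactly as the deleted loop did.
    static void compress_positions(int32_t n_past, int32_t ga_n, int32_t ga_w,
                                   int32_t & n_past_se, int32_t & ga_i) {
        ga_i = 0;
        int32_t slot_npast = 0;
        for (int k = 0; k < n_past; ++k) {
            while (slot_npast >= ga_i + ga_w) {
                const int bd = (ga_w/ga_n)*(ga_n - 1); // shift applied per full window
                slot_npast -= bd;
                ga_i       += ga_w/ga_n;
            }
            slot_npast++;
        }
        n_past_se = slot_npast;
    }

    int main() {
        // Example: 4x self-extend over 512-token windows, 1300 cached tokens.
        int32_t n_past_se = 0;
        int32_t ga_i      = 0;
        compress_positions(/*n_past=*/1300, /*ga_n=*/4, /*ga_w=*/512, n_past_se, ga_i);
        std::printf("n_past_se = %d, ga_i = %d\n", (int) n_past_se, (int) ga_i);
    }

For ga_n = 4, ga_w = 512 and 1300 cached tokens this yields n_past_se = 532 and ga_i = 256, i.e. the cached prefix occupies different KV positions than a plain (ga_n == 1) slot would use; that mismatch is what the commit sidesteps by disabling prompt caching under self-extend instead of keeping this remapping.
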
@@ -1841,15 +1831,13 @@ struct llama_server_context {
                     {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                 });
 
-                std::vector<llama_token> prefix_tokens = prompt_tokens;
-
                 int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
 
                 int32_t ga_i = slot.ga_i;
                 int32_t ga_n = slot.ga_n;
                 int32_t ga_w = slot.ga_w;
 
-                for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) {
+                for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
                     if (slot.ga_n != 1) {
                         while (slot_npast >= ga_i + ga_w) {
                             const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1858,7 +1846,7 @@ struct llama_server_context {
                         }
                     }
 
-                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                    llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
 
                     slot_npast++;
                 }
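
In the loop above, prefix_tokens was only a redundant copy of prompt_tokens, so indexing prompt_tokens directly does not change behaviour. The relevant detail is that the position handed to llama_batch_add is slot_npast (the self-extend-compressed counter) plus the system-prompt offset, not slot.n_past itself. A small hypothetical program (not server code) that prints the index-to-position mapping this loop effectively applies, using toy ga_n/ga_w values:

    // Sketch of how the batching loop maps a token index to the KV-cache position
    // it is written at when self-extend is active: slot_npast advances like n_past
    // but is pulled back by bd = (ga_w/ga_n)*(ga_n - 1) at every group boundary.
    #include <cstdio>
    #include <cstdint>

    int main() {
        const int32_t ga_n = 2;   // toy group-attention factor
        const int32_t ga_w = 8;   // toy group-attention width (small, for readability)

        int32_t ga_i       = 0;
        int32_t slot_npast = 0;

        for (int32_t n_past = 0; n_past < 24; ++n_past) {
            if (ga_n != 1) {
                while (slot_npast >= ga_i + ga_w) {
                    const int bd = (ga_w/ga_n)*(ga_n - 1);
                    slot_npast -= bd;
                    ga_i       += ga_w/ga_n;
                }
            }
            // In the server this position is what llama_batch_add receives
            // (plus the system-prompt offset).
            std::printf("token %2d -> kv position %2d\n", (int) n_past, (int) slot_npast);
            slot_npast++;
        }
    }

With ga_n = 2 and ga_w = 8, tokens 0-7 land at positions 0-7, token 8 falls back to position 4, and so on; with ga_n == 1 the two counters stay identical, which is the only case where a previously cached prefix can be reused as-is.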