context shift fixed

parent 2d9f11db28
commit d7eca255d7

1 changed file with 39 additions and 20 deletions
@@ -757,8 +757,8 @@ struct llama_server_context
                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                slot.sent_count += result.text_to_send.size();
                // add the token to slot queue and cache
                slot.addTokenString(result);
            }
            slot.addTokenString(result);
            if (slot.multibyte_pending > 0)
            {
                slot.multibyte_pending -= token_str.size();
@@ -925,8 +925,8 @@ struct llama_server_context
         }

         // context shift takes effect only when there is a single slot
-        if(slots.size() == 1) {
-            llama_client_slot slot = slots[0];
+        if(params.n_parallel == 1) {
+            llama_client_slot &slot = slots[0];
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
             {
                 // Shift context
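For reference, the shift this condition gates normally keeps the first n_keep tokens of the slot, drops roughly half of what remains, and slides the kept tail back so generation can continue past n_ctx. A minimal sketch of that idea, assuming the llama_kv_cache_seq_rm / llama_kv_cache_seq_shift calls from llama.h of this period; the exact body lives in the lines the hunk truncates, so treat the details below as an illustration, not the committed code:

    // Illustrative context shift for one sequence (slot.id), not the committed body.
    const int n_keep    = params.n_keep + 1;          // +1 for BOS is an assumption here
    const int n_left    = slot.n_past - n_keep;
    const int n_discard = n_left / 2;                 // drop half of the non-kept tokens

    // erase the discarded span, then shift the surviving tail left by n_discard
    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep,             n_keep + n_discard);
    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

    slot.n_past -= n_discard;                         // slot.cache_tokens must be trimmed to match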
@@ -1028,22 +1028,16 @@ struct llama_server_context

             slot.num_prompt_tokens = prompt_tokens.size();
-
-            slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;
-
-            slot.cache_tokens = prompt_tokens;
-
-            if (slot.n_past == slot.num_prompt_tokens) {
-                // we have to evaluate at least 1 token to generate logits.
-                printf("we have to evaluate at least 1 token to generate logits\n");
-                slot.n_past--;
-            }
-
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+
+            if(!slot.params.cache_prompt) {
+                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                slot.n_past = 0;
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+            } else {
                 LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                 if (params.n_keep < 0 && params.n_parallel == 1)
                 {
                     params.n_keep = (int)slot.num_prompt_tokens;
                 }
                 params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx)
                 {
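common_part, used above to decide how much of a cached prompt can be reused, is essentially a longest-common-prefix count over two token vectors: everything before the first mismatch can stay in the KV cache, everything after it has to be re-evaluated. A small self-contained sketch of that idea (written for illustration, not copied from this file):

    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;   // matches the llama.h typedef

    // number of leading tokens shared by the cached prompt and the new prompt
    static size_t common_part(const std::vector<llama_token> & a,
                              const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }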
@@ -1059,14 +1053,26 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
                     slot.num_prompt_tokens = prompt_tokens.size();
                 }
                 const size_t ps = slot.num_prompt_tokens;
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
                 std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+            }
+
+            llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
+
+            slot.cache_tokens = prompt_tokens;
+
+            if (slot.n_past == slot.num_prompt_tokens) {
+                // we have to evaluate at least 1 token to generate logits.
+                printf("we have to evaluate at least 1 token to generate logits\n");
+                slot.n_past--;
+            }
+
             LOG_VERBOSE("prompt ingested", {
                 {"n_past", slot.n_past},
                 {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
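The n_past-- guard reappearing above exists because llama_decode only produces logits for tokens that are actually submitted in a batch; if the whole prompt is already covered by the cache, nothing would be evaluated and there would be no logits to sample the first new token from. A sketch of that reasoning with plain integers (illustrative only, not the server code):

    // full prefix hit: every prompt token is already in the KV cache
    int num_prompt_tokens = 32;
    int n_past            = 32;

    int to_evaluate = num_prompt_tokens - n_past;   // 0 -> nothing decoded, no logits
    if (to_evaluate == 0) {
        n_past      -= 1;                           // re-decode the last prompt token
        to_evaluate  = 1;                           // its logits seed the first sampled token
    }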
@@ -1185,7 +1191,7 @@ struct llama_server_context
             }
         }

-        if(kv_cache_free < 0) {
+        if(kv_cache_free < 0 && params.n_parallel > 1) {
             LOG_TEE("\nError: kv cache is full, increase context size.");
             return false;
         }
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }

+static void slot_print_timings(struct llama_client_slot * slot) {
+    LOG_TEE("\n");
+    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
+    LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
+    LOG_TEE("%s: total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
+}
+
 static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
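The per-token and tokens-per-second figures printed by slot_print_timings are plain unit conversions of the accumulated millisecond totals. A worked example with made-up numbers:

    #include <cstdio>

    int main() {
        // illustrative numbers only: 100 prompt tokens processed in 250 ms
        const double t_prompt_ms = 250.0;
        const int    n_tokens    = 100;

        const double ms_per_token   = t_prompt_ms / n_tokens;        // 2.50 ms per token
        const double tokens_per_sec = 1e3 / t_prompt_ms * n_tokens;  // 400.00 tokens per second

        printf("%8.2f ms per token, %8.2f tokens per second\n", ms_per_token, tokens_per_sec);
        return 0;
    }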
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
         {"penalize_nl", slot->sparams.penalize_nl},
         {"stop", slot->params.antiprompt},
         {"n_predict", slot->params.n_predict},
-        // {"n_keep", slot.params.n_keep},
+        {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", slot->params.stream},
         {"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", -1);
+    llama.params.n_keep = json_value(body, "n_keep", 0);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
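The only change above is the fallback for n_keep when the request body does not set it: 0 instead of -1, so a request without n_keep no longer falls into the n_keep < 0 branch shown earlier, which keeps the whole prompt. json_value behaves like a field-or-default helper; a hypothetical sketch of that pattern with nlohmann::json (names and details assumed, not copied from the server code):

    #include <string>
    #include <nlohmann/json.hpp>

    // return body[key] when present and non-null, otherwise the supplied default
    template <typename T>
    static T json_value(const nlohmann::json & body, const std::string & key, const T & default_value) {
        return body.contains(key) && !body.at(key).is_null()
            ? body.at(key).get<T>()
            : default_value;
    }

    // usage matching the hunk: n_keep falls back to 0 when the client omits it
    // llama.params.n_keep = json_value(body, "n_keep", 0);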
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
             }

             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             slot->release();
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
             }

             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
         } else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +