server : hide ctx_sampling->prev behind API (#3696)

This commit is contained in:
Georgi Gerganov 2023-10-22 20:09:01 +03:00
parent 3d6a687f1d
commit 00ae55b388
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1559,7 +1559,8 @@ struct llama_server_context
if (!slot.params.cache_prompt) if (!slot.params.cache_prompt)
{ {
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0); llama_sampling_reset(slot.ctx_sampling);
slot.n_past = 0; slot.n_past = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens; slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
} }
@ -1570,16 +1571,17 @@ struct llama_server_context
slot.params.n_keep = slot.num_prompt_tokens; slot.params.n_keep = slot.num_prompt_tokens;
} }
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
//if input prompt is too big, truncate like normal
// if input prompt is too big, truncate it
if (slot.num_prompt_tokens >= slot.n_ctx) if (slot.num_prompt_tokens >= slot.n_ctx)
{ {
// applied bug of #3661
const int n_left = slot.n_ctx - slot.params.n_keep; const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2; const int n_block_size = n_left / 2;
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
// Use half the left-over space in the context for the prompt
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
LOG_VERBOSE("input truncated", { LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx}, {"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep}, {"n_keep", slot.params.n_keep},
@ -1588,14 +1590,20 @@ struct llama_server_context
}); });
slot.truncated = true; slot.truncated = true;
prompt_tokens = new_tokens; prompt_tokens = new_tokens;
slot.num_prompt_tokens = prompt_tokens.size(); slot.num_prompt_tokens = prompt_tokens.size();
GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
} }
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0); // push the prompt into the sampling context (do not apply grammar)
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps); for (auto &token : prompt_tokens)
{
llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
}
slot.n_past = common_part(slot.cache_tokens, prompt_tokens); slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
} }