server : take system_tokens into account

This commit is contained in:
Georgi Gerganov 2024-01-29 15:52:18 +02:00
parent 51bb7f0eef
commit 8772d3ee63
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -1225,7 +1225,7 @@ struct llama_server_context
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i)
{
llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1;
}
}
@ -1376,12 +1376,12 @@ struct llama_server_context
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
// Shift context
const int n_left = slot.n_past - slot.params.n_keep - 1;
const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
const int n_discard = n_left / 2;
LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);
for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
@ -1426,6 +1426,8 @@ struct llama_server_context
slot.i_batch = batch.n_tokens;
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1;
@ -1478,8 +1480,8 @@ struct llama_server_context
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}