From 8772d3ee638f11e575a5a38b82bb58b05b364aef Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 15:52:18 +0200 Subject: [PATCH] server : take system_tokens into account --- examples/server/server.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 913e6098e..17039e8d9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1225,7 +1225,7 @@ struct llama_server_context std::vector append_tokens = tokenize(json_prompt, false); // has next image for (int i = 0; i < (int) append_tokens.size(); ++i) { - llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true); + llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true); slot.n_past += 1; } } @@ -1376,12 +1376,12 @@ struct llama_server_context if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) { // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1; const int n_discard = n_left / 2; LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); + llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard); for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) { @@ -1426,6 +1426,8 @@ struct llama_server_context slot.i_batch = batch.n_tokens; + // TODO: we always have to take into account the "system_tokens" + // this is not great and needs to be improved somehow llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -1478,8 +1480,8 @@ struct llama_server_context prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.push_back(llama_token_middle(model)); prompt_tokens = prefix_tokens; }