From 05350f286289b1b95e6712af9fca36e8b38e0e62 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 30 Jan 2024 19:05:36 +0200
Subject: [PATCH] server : revert n_past_se changes

---
 examples/server/server.cpp | 80 ++++++++++++++++++++++----------------
 1 file changed, 47 insertions(+), 33 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7ff3622f9..21bdce8ed 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -188,6 +188,8 @@ struct llama_client_slot
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
+    int32_t n_past_se = 0; // self-extend
+
     // multimodal
     std::vector<slot_image> images;
 
@@ -217,6 +219,7 @@ struct llama_client_slot
         sent_token_probs_index = 0;
         infill                 = false;
         ga_i                   = 0;
+        n_past_se              = 0;
 
         generated_token_probs.clear();
 
@@ -1293,7 +1296,8 @@ struct llama_server_context
             for (llama_client_slot &slot : slots)
             {
                 slot.cache_tokens.clear();
-                slot.n_past = 0;
+                slot.n_past    = 0;
+                slot.n_past_se = 0;
             }
         }
 
@@ -1427,9 +1431,11 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
+            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
             slot.n_past += 1;
         }
@@ -1526,6 +1532,7 @@ struct llama_server_context
                         llama_sampling_reset(slot.ctx_sampling);
 
                         slot.n_past    = 0;
+                        slot.n_past_se = 0;
                         slot.ga_i      = 0;
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                     }
@@ -1540,6 +1547,25 @@ struct llama_server_context
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
+                        if (slot.ga_n != 1)
+                        {
+                            int ga_i = 0;
+                            int32_t ga_n = slot.ga_n;
+                            int32_t ga_w = slot.ga_w;
+                            int32_t slot_npast = 0;
+                            for (int k = 0; k < slot.n_past; ++k)
+                            {
+                                while (slot_npast >= ga_i + ga_w) {
+                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                    slot_npast -= bd;
+                                    ga_i += ga_w/ga_n;
+                                }
+                                slot_npast++;
+                            }
+                            slot.n_past_se = slot_npast;
+                            slot.ga_i = ga_i;
+                        }
+
                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                     }
 
@@ -1554,6 +1580,10 @@ struct llama_server_context
                         // we have to evaluate at least 1 token to generate logits.
                         LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                         slot.n_past--;
+                        if (slot.ga_i > 0)
+                        {
+                            slot.n_past_se--;
+                        }
                     }
 
                     LOG_VERBOSE("prompt ingested", {
@@ -1562,32 +1592,13 @@ struct llama_server_context
                                                        {"n_past",  slot.n_past},
                                                        {"cached",  tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
                                                        {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                                                    });
-                    if (slot.ga_n != 1)
-                    {
-                        int ga_i = 0;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-                        int32_t slot_npast = 0;
-                        for (int k = 0; k < slot.n_past; ++k)
-                        {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                            slot_npast++;
-                        }
-                        slot.n_past = slot_npast;
-                        slot.ga_i = ga_i;
-
-                        LOG_TEE("slot %d : applied self-extend to prompt: %i tokens\n", slot.id, slot.n_past);
-                    }
-
                     const bool has_images = process_images(slot);
 
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
                     int32_t ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;
@@ -1596,13 +1607,14 @@ struct llama_server_context
                     {
                         if (slot.ga_n != 1)
                         {
-                            while (slot.n_past >= ga_i + ga_w) {
+                            while (slot_npast >= ga_i + ga_w) {
                                 const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot.n_past -= bd;
+                                slot_npast -= bd;
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        slot_npast++;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1638,29 +1650,31 @@ struct llama_server_context
             if (slot.ga_n != 1)
             {
                 // context extension via Self-Extend
-                while (slot.n_past >= slot.ga_i + slot.ga_w)
+                while (slot.n_past_se >= slot.ga_i + slot.ga_w)
                 {
                     const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                     const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                     const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                     LOG_TEE("\n");
-                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
+                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
                     LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
+                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                     llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past + ib * bd, dd);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
-                    slot.n_past -= bd;
+                    slot.n_past_se -= bd;
 
                     slot.ga_i += slot.ga_w / slot.ga_n;
 
-                    LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                    LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
                 }
+
+                slot.n_past_se += n_tokens;
             }
         }
 
         llama_batch batch_view = {
            n_tokens,
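
// --- Editor's sketch (not part of the patch) --------------------------------
// The loop this patch adds in the prompt-cache hunk (@@ -1540,6 +1547,25 @@)
// and mirrors during prefix ingestion can be read as a pure function mapping a
// raw token count to its compressed KV-cache position under group attention.
// Below is a minimal standalone rendering of that computation, under the
// assumption that n_past, ga_n and ga_w behave like the slot fields in the
// diff; the function name self_extend_position and the main() driver are
// hypothetical, not part of server.cpp.

#include <cstdint>
#include <cstdio>

// Map a raw token count (n_past) to its self-extend position (n_past_se).
// ga_n: group-attention factor, ga_w: group-attention width.
static int32_t self_extend_position(int32_t n_past, int32_t ga_n, int32_t ga_w) {
    int32_t ga_i = 0; // start of the current, not-yet-compressed window
    int32_t pos  = 0; // compressed position; advances like slot.n_past_se
    for (int32_t k = 0; k < n_past; ++k) {
        // once a full window of ga_w positions lies at or beyond ga_i,
        // compress it by the factor ga_n and advance the window start
        while (pos >= ga_i + ga_w) {
            const int32_t bd = (ga_w/ga_n)*(ga_n - 1); // backward shift
            pos  -= bd;
            ga_i += ga_w/ga_n;
        }
        pos++;
    }
    return pos;
}

int main() {
    // with ga_n = 4 and ga_w = 512, a 2048-token prompt lands on a much
    // smaller KV-cache position, which is why the patch has to track n_past
    // (real tokens) and n_past_se (cache positions) separately
    printf("n_past = 2048 -> n_past_se = %d\n", self_extend_position(2048, 4, 512));
    return 0;
}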