diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2760aea8f..28a1e1695 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -100,10 +100,11 @@ struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
 
-    uint32_t seed      = -1; // RNG seed
-    int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
-    int32_t  n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t  n_predict = -1; // new tokens to predict
+    uint32_t seed       = -1; // RNG seed
+    int32_t  n_keep     =  0; // number of tokens to keep from initial prompt
+    int32_t  n_discard  =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t  n_truncate =  0; // number of tokens after n_keep to erase when the prompt is too long, 0 defaults to half
+    int32_t  n_predict  = -1; // new tokens to predict
 
     std::vector<std::string> antiprompt;
 
@@ -170,6 +171,7 @@ struct server_slot {
     bool infill         = false;
     bool embedding      = false;
     bool has_next_token = true;
+    bool shifted        = false;
     bool truncated      = false;
     bool stopped_eos    = false;
     bool stopped_word   = false;
@@ -190,7 +192,7 @@ struct server_slot {
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
-    int32_t n_past_se = 0; // self-extend
+    int32_t n_past_se = 0;   // self-extend
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -205,6 +207,7 @@ struct server_slot {
     void reset() {
         n_prompt_tokens = 0;
         generated_text  = "";
+        shifted         = false;
         truncated       = false;
         stopped_eos     = false;
         stopped_word    = false;
@@ -701,7 +704,7 @@ struct server_context {
 
         return res > 0;
     }
-
+    //MARK: Init
     void init() {
         const int32_t n_ctx_slot = n_ctx / params.n_parallel;
 
@@ -854,6 +857,7 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep       = json_value(data, "n_keep",      slot.params.n_keep);
         slot.params.n_discard    = json_value(data, "n_discard",   default_params.n_discard);
+        slot.params.n_truncate   = json_value(data, "n_truncate",  default_params.n_truncate);
         slot.sparams.seed        = json_value(data, "seed",        default_sparams.seed);
         slot.sparams.n_probs     = json_value(data, "n_probs",     default_sparams.n_probs);
         slot.sparams.min_keep    = json_value(data, "min_keep",    default_sparams.min_keep);
@@ -1282,6 +1286,7 @@ struct server_context {
            {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
            {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},
+           {"n_truncate", slot.params.n_truncate},
            {"ignore_eos", ignore_eos},
            {"stream", slot.params.stream},
            {"logit_bias", slot.sparams.logit_bias},
@@ -1369,6 +1374,7 @@ struct server_context {
            {"tokens_evaluated", slot.n_prompt_tokens},
            {"generation_settings", get_formated_generation(slot)},
            {"prompt", slot.prompt},
+           {"shifted", slot.shifted},
            {"truncated", slot.truncated},
            {"stopped_eos", slot.stopped_eos},
            {"stopped_word", slot.stopped_word},
@@ -1784,6 +1790,7 @@ struct server_context {
                    {"n_past", slot.n_past},
                    {"n_system_tokens", system_tokens.size()},
                    {"n_cache_tokens", slot.cache_tokens.size()},
+                   {"shifted", slot.shifted},
                    {"truncated", slot.truncated}
                });
 
@@ -1857,7 +1864,7 @@ struct server_context {
 
                slot.n_past -= n_discard;
 
-                slot.truncated = true;
+                slot.shifted = true;
            }
        }
    }
@@ -1892,6 +1899,7 @@ struct server_context {
                    {"n_past", slot.n_past},
                    {"n_system_tokens", system_tokens.size()},
                    {"n_cache_tokens", slot.cache_tokens.size()},
+                   {"shifted", slot.shifted},
                    {"truncated", slot.truncated}
                });
            }
@@ -1988,9 +1996,11 @@ struct server_context {
                    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
                        const int n_left = slot.n_ctx - slot.params.n_keep;
+                        const int n_truncate = slot.params.n_truncate;
 
                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                        const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
+                        const int erased_tokens = n_truncate ? n_truncate : half_ctx;
 
                        std::vector<llama_token> new_tokens(
                                prompt_tokens.begin(),
@@ -1998,7 +2008,7 @@ struct server_context {
 
                        new_tokens.insert(
                                new_tokens.end(),
-                                prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                                prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
                                prompt_tokens.end());
 
                        prompt_tokens = std::move(new_tokens);
@@ -2006,6 +2016,17 @@ struct server_context {
                        slot.truncated = true;
                        slot.n_prompt_tokens = prompt_tokens.size();
 
+                        if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
+                            LOG_ERROR("prompt - n_truncate > n_ctx", {
+                                {"id_slot", slot.id},
+                                {"id_task", slot.id_task},
+                                {"n_ctx", slot.n_ctx},
+                                {"n_keep", slot.params.n_keep},
+                                {"n_left", n_left},
+                                {"n_truncate", n_truncate},
+                            });
+                        }
+
                        LOG_VERBOSE("input truncated", {
                            {"id_slot", slot.id},
                            {"id_task", slot.id_task},
@@ -2016,6 +2037,29 @@ struct server_context {
                            {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                        });
 
+                        if (n_truncate && slot.params.cache_prompt) {
+                            const int n_keep = slot.params.n_keep + add_bos_token;
+                            for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
+                                slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
+                            }
+
+                            const int new_cache_size = slot.cache_tokens.size() - n_truncate;
+                            if (new_cache_size >= 0) {
+                                slot.cache_tokens.resize(new_cache_size);
+
+                                LOG_VERBOSE("cache tokens shift", {
+                                    {"id_slot", slot.id},
+                                    {"id_task", slot.id_task},
+                                    {"n_ctx", slot.n_ctx},
+                                    {"n_keep", slot.params.n_keep},
+                                    {"n_left", n_left},
+                                    {"n_truncate", n_truncate},
+                                    {"new_cache_size", new_cache_size},
+                                    {"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
+                                });
+                            } // else: n_truncate was requested without a previously cached prompt
+                        }
+
                        GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                    }
 
@@ -2030,6 +2074,12 @@ struct server_context {
                        // reuse any previously computed tokens that are common with the new prompt
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
+                        LOG_INFO("[cached_tokens, prompt_tokens]", {
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task },
+                            { "common_part", slot.n_past }
+                        });
+
                        // push the prompt into the sampling context (do not apply grammar)
                        for (int i = 0; i < slot.n_past; ++i) {
                            llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2060,6 +2110,30 @@ struct server_context {
                        }
                    }
 
+                    // shift KV cache if needed
+                    const int n_keep = slot.params.n_keep + add_bos_token;
+                    const int n_truncate = slot.params.n_truncate;
+                    if (n_truncate && slot.params.cache_prompt) {
+                        llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
+
+                        LOG_INFO("kv cache rm", {
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task },
+                            { "p0", n_keep },
+                            { "p1", n_keep + n_truncate }
+                        });
+
+                        llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
+
+                        LOG_INFO("kv cache add", {
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task },
+                            { "p0", n_keep + n_truncate },
+                            { "p1", slot.n_past },
+                            { "delta", -n_truncate }
+                        });
+                    }
+
                    // keep only the common part
                    int p0 = (int) system_tokens.size() + slot.n_past;
                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {