diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index a859e9d55..e1ad0d75f 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -570,6 +570,33 @@ static void load_grammar(const std::string & gammarstr) } } +//given an old GGUF context and a new context that has some middle portion removed, +//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action +void PurgeMissingTokens(std::vector ¤t_context_tokens, std::vector &new_context_tokens) +{ + //scan from start old and new ctx, until first mismatch found, save as p0 + //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens + //test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails + //if passed, save beginning of LCQ from old ctx as p1 + //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal + + // int trimstart = 0; + + // const int n_keep = 0; + // const int n_left = n_past - n_keep - 1; + // const int n_discard = n_left/2; + + // printf("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + // n_past, n_left, nctx, n_keep, n_discard); + + // llama_kv_cache_seq_rm (llama_ctx_v4, 0, n_keep + 1 , n_keep + n_discard + 1); + // llama_kv_cache_seq_shift(llama_ctx_v4, 0, n_keep + 1 + n_discard, n_past, -n_discard); + // n_past -= n_discard; + + // printf("after swap: n_past = %d\n", n_past); + +} + ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta) { ggml_time_init(); @@ -1371,6 +1398,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); n_past = 0; + PurgeMissingTokens(current_context_tokens, embd_inp); + if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) { ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);