This commit is contained in:
Concedo 2023-10-26 21:58:51 +08:00
parent 5db89b90b7
commit 0f46534866

View file

@ -570,6 +570,33 @@ static void load_grammar(const std::string & gammarstr)
} }
} }
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
{
//scan from start old and new ctx, until first mismatch found, save as p0
//check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
//test: longest common subseq (LCQ) MUST start within 0 tokens from end of memory, otherwise purge fails
//if passed, save beginning of LCQ from old ctx as p1
//remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
// int trimstart = 0;
// const int n_keep = 0;
// const int n_left = n_past - n_keep - 1;
// const int n_discard = n_left/2;
// printf("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
// n_past, n_left, nctx, n_keep, n_discard);
// llama_kv_cache_seq_rm (llama_ctx_v4, 0, n_keep + 1 , n_keep + n_discard + 1);
// llama_kv_cache_seq_shift(llama_ctx_v4, 0, n_keep + 1 + n_discard, n_past, -n_discard);
// n_past -= n_discard;
// printf("after swap: n_past = %d\n", n_past);
}
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta) ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta)
{ {
ggml_time_init(); ggml_time_init();
@ -1371,6 +1398,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
n_past = 0; n_past = 0;
PurgeMissingTokens(current_context_tokens, embd_inp);
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{ {
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true); ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);