fixes to smartcontextpro

Concedo 2023-10-29 10:42:37 +08:00
parent 20ef442c2a
commit 338d6c265d
2 changed files with 25 additions and 13 deletions

View file

@@ -247,7 +247,7 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
 return word2;
 }
-static std::string print_tok_vec_str(std::vector<int> &embd)
+static std::string get_tok_vec_str(std::vector<int> &embd)
 {
 std::string tmp = "";
 for (auto id : embd)
@@ -257,6 +257,10 @@ static std::string print_tok_vec_str(std::vector<int> &embd)
 ::utreplace(tmp, "\n", "\\n");
 return tmp;
 }
+static void print_tok_vec_str(std::vector<int> &vec)
+{
+printf("\n%s", get_tok_vec_str(vec).c_str());
+}
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
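
These two hunks split the old helper in two: get_tok_vec_str only formats the token vector into a string, and the new print_tok_vec_str is a thin wrapper that prints it. A minimal self-contained sketch of the same pattern (the formatter body is illustrative; the real helper also escapes newlines via koboldcpp's ::utreplace):

#include <cstdio>
#include <string>
#include <vector>

// Formatter: build the debug string, but do not print it.
static std::string tok_vec_str(const std::vector<int> &embd)
{
    std::string tmp;
    for (auto id : embd)
    {
        tmp += std::to_string(id) + " ";
    }
    return tmp;
}

// Printer: thin wrapper over the formatter, mirroring the new print_tok_vec_str.
static void print_tok_vec(const std::vector<int> &vec)
{
    printf("\n%s", tok_vec_str(vec).c_str());
}

int main()
{
    std::vector<int> toks = {1, 15043, 3186};
    print_tok_vec(toks);                    // quick console dump
    std::string report = tok_vec_str(toks); // or accumulate into a larger report
    printf("\n%zu tokens formatted\n", toks.size());
    return 0;
}

Separating the two lets the debug dump later in this commit concatenate several token vectors into one buffer before printing.
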
@@ -583,7 +587,7 @@ static void load_grammar(const std::string & gammarstr)
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
+void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
 {
 //scan from start old and new ctx, until first mismatch found, save as p0
 //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -591,8 +595,8 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
 //if passed, save beginning of LCQ from old ctx as p1
 //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
-const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
-const int SlackAllowance = 64; //in case the end text is slightly modified, be forgiving
+const int ShortfallThreshold = 200; //dont trigger shifting if the distance between trimstart and currhead < this
+const int SlackAllowance = 50; //in case the end text is slightly modified, be forgiving
 int trimstart = 0;
 int new_tokens_len = new_context_tokens.size();
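
The comment block above describes the trimming strategy: keep the shared prefix, look for a sufficiently long common run later in both token streams, and purge the old tokens in between from the KV cache. The sketch below is a self-contained illustration of the matching step only (standard C++ with illustrative names; it is not the koboldcpp implementation and does not touch the KV cache):

#include <algorithm>
#include <cstdio>
#include <vector>

// Length of the shared prefix of two token streams (p0 in the comments above).
static int common_prefix_len(const std::vector<int> &a, const std::vector<int> &b)
{
    int n = (int)std::min(a.size(), b.size());
    int i = 0;
    while (i < n && a[i] == b[i]) { ++i; }
    return i;
}

// Longest common contiguous run of the two remainders (classic O(n*m) DP).
// Returns its length; a_start receives where the run begins inside 'a'.
static int longest_common_run(const std::vector<int> &a, const std::vector<int> &b, int &a_start)
{
    int n = (int)a.size(), m = (int)b.size(), best = 0;
    a_start = 0;
    std::vector<int> prev(m + 1, 0), cur(m + 1, 0);
    for (int i = 1; i <= n; ++i)
    {
        for (int j = 1; j <= m; ++j)
        {
            cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1 : 0;
            if (cur[j] > best) { best = cur[j]; a_start = i - best; }
        }
        std::swap(prev, cur);
    }
    return best;
}

int main()
{
    std::vector<int> old_ctx = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int> new_ctx = {1, 2, 3, 6, 7, 8, 9};
    int p0 = common_prefix_len(old_ctx, new_ctx);             // 3
    std::vector<int> old_rest(old_ctx.begin() + p0, old_ctx.end());
    std::vector<int> new_rest(new_ctx.begin() + p0, new_ctx.end());
    int start = 0;
    int run = longest_common_run(old_rest, new_rest, start);  // run {6, 7, 8}
    printf("p0=%d run=%d p1=%d\n", p0, run, p0 + start);      // tokens [p0, p1) are the purge candidates
    return 0;
}

If the common run is at least LCSTokThreshold tokens long, the old tokens between p0 and p1 are the ones removed from current_context_tokens and from the KV cache.
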
@@ -621,7 +625,7 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
 }
 //at least this many tokens need to match, otherwise don't bother trimming
-const int LCSTokThreshold = std::max((new_tokens_len - trimstart) - (genamt+SlackAllowance), ShortfallThreshold-SlackAllowance);
+const int LCSTokThreshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+SlackAllowance), (int)(nctx*0.55)), ShortfallThreshold-SlackAllowance);
 auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
 auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
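
The new std::min term caps the required match length at roughly 55% of the context window, so a long prompt with a large generation budget no longer demands an implausibly long common run before trimming is allowed. A worked example with made-up numbers (none of these values appear in the commit):

#include <algorithm>
#include <cstdio>

int main()
{
    // Hypothetical values for illustration only.
    const int nctx = 2048, genamt = 100;
    const int new_tokens_len = 2000, trimstart = 100;
    const int ShortfallThreshold = 200, SlackAllowance = 50;

    const int uncapped = (new_tokens_len - trimstart) - (genamt + SlackAllowance); // 1750
    const int capped   = std::min(uncapped, (int)(nctx * 0.55));                   // min(1750, 1126) = 1126
    const int thr      = std::max(capped, ShortfallThreshold - SlackAllowance);    // max(1126, 150)  = 1126

    printf("uncapped threshold: %d, capped threshold: %d\n",
           std::max(uncapped, ShortfallThreshold - SlackAllowance), thr);          // 1750 vs 1126
    return 0;
}

With the cap, a common run of about 1126 tokens is enough to trigger the purge here; without it, nearly the whole 1750-token remainder would have to match.
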
@@ -956,7 +960,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 //determine mem per token
 std::vector<int> tmp = {1, 2, 3, 4};
-auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0);
+llama_kv_cache_tokens_rm(llama_ctx_v4, -1, -1);
+auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
 if(er!=0)
 {
 printf("\nLLAMA EVAL returned nonzero!\n");
@@ -1459,12 +1464,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 else
 {
 bool triggersc = useSmartContext;
-if(useSmartContext && file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
 {
-PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
-triggersc = false;
+if(useSmartContext)
+{
+PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
+triggersc = false;
+}
 }
 ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
+if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+{
+llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+}
 }
 //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
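
Two behavioural changes land in this hunk: PurgeMissingTokens now runs only when smart context is enabled (and receives nctx so the new threshold cap can be computed), and for GGUF llama/falcon models the KV cache is always trimmed to the fast-forwarded position afterwards. A hedged sketch of that trim, with ctx and n_past standing in for the variables used above:

#include "llama.h"

// After ContextFastForward has decided how many leading tokens (n_past) can be
// reused, drop every cached entry of sequence 0 at position >= n_past so the
// next llama_decode cannot see stale state beyond the reused prefix.
// p1 = -1 means "to the end of the sequence".
void drop_stale_kv(llama_context * ctx, int n_past)
{
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/n_past, /*p1=*/-1);
}
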
@@ -1603,9 +1615,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 {
 std::string outstr = "";
 printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
-outstr += print_tok_vec_str(embd_inp);
+outstr += get_tok_vec_str(embd_inp);
 outstr += "\n\n[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
-outstr += print_tok_vec_str(current_context_tokens);
+outstr += get_tok_vec_str(current_context_tokens);
 printf("%s\n\n", RemoveBell(outstr).c_str());
 }
@@ -1636,7 +1648,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 }
 else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
 {
-evalres = (llama_eval(llama_ctx_v4, embd.data(), embdsize, n_past)==0);
+evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
 }
 else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
 {

View file

@@ -9836,7 +9836,7 @@ int llama_eval(
 llama_token * tokens,
 int32_t n_tokens,
 int n_past) {
-llama_kv_cache_seq_rm(ctx->kv_self, 0, n_past, -1);
+llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
 if (ret < 0) {
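
Reading of this second-file change, inferred from the diff: the llama_eval compatibility wrapper now falls back to llama_kv_cache_tokens_rm, which removes cached cells by index, while the sequence/position-based llama_kv_cache_seq_rm is called explicitly by the adapter after fast-forwarding (see the earlier hunk). The two removal calls as used in this commit, with ctx and n_past assumed to come from the surrounding code:

#include "llama.h"

// Side by side: the two KV-cache removal calls that appear in this commit.
void kv_removal_examples(llama_context * ctx, int n_past)
{
    // By sequence and position: drop sequence 0 entries from position n_past
    // to the end (p1 = -1). Used in gpttype_generate after ContextFastForward.
    llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

    // By cell index: drop cached cells from index n_past to the end (negative
    // bounds extend to the cache limits). This is the public counterpart of the
    // call the llama_eval wrapper makes on ctx->kv_self above.
    llama_kv_cache_tokens_rm(ctx, n_past, -1);
}
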