fixes to smartcontextpro
parent 20ef442c2a
commit 338d6c265d
2 changed files with 25 additions and 13 deletions
@@ -247,7 +247,7 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
     return word2;
 }
 
-static std::string print_tok_vec_str(std::vector<int> &embd)
+static std::string get_tok_vec_str(std::vector<int> &embd)
 {
     std::string tmp = "";
     for (auto id : embd)
@@ -257,6 +257,10 @@ static std::string print_tok_vec_str(std::vector<int> &embd)
     ::utreplace(tmp, "\n", "\\n");
     return tmp;
 }
+static void print_tok_vec_str(std::vector<int> &vec)
+{
+    printf("\n%s", get_tok_vec_str(vec).c_str());
+}
 
 
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
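Note: the refactor splits the old print-only helper in two, so a token dump can either be returned as a string (and appended to a debug buffer, as in a later hunk) or printed directly. A minimal standalone sketch of the same pattern; the ::utreplace escaping from the real helper is omitted, and the token ids are made up:

    #include <cstdio>
    #include <string>
    #include <vector>

    // Builds a space-separated dump of token ids; the caller decides what to do with it.
    static std::string get_tok_vec_str(std::vector<int> &embd)
    {
        std::string tmp = "";
        for (auto id : embd)
        {
            tmp += std::to_string(id) + " ";
        }
        return tmp;
    }

    // Thin wrapper for the common "just print it" case.
    static void print_tok_vec_str(std::vector<int> &vec)
    {
        printf("\n%s", get_tok_vec_str(vec).c_str());
    }

    int main()
    {
        std::vector<int> toks = {1, 15043, 3186, 2}; // hypothetical token ids
        std::string dump = get_tok_vec_str(toks);    // reusable inside a larger debug string
        print_tok_vec_str(toks);                     // or printed directly
        printf("\n(dump length: %zu)\n", dump.size());
        return 0;
    }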
@@ -583,7 +587,7 @@ static void load_grammar(const std::string & gammarstr)
 
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
+void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
 {
     //scan from start old and new ctx, until first mismatch found, save as p0
     //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
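The comments above outline the approach: find the first mismatch between the old and new token streams (p0), look for a sufficiently long re-match of the new suffix inside the old context (p1), and purge only the [p0, p1) slice from the KV cache. A simplified, self-contained illustration of that idea on plain vectors; it uses a naive std::search probe in place of the real longest-common-subsequence check and only reports the range that would be purged:

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Returns the [p0, p1) range of old_ctx that a purge would remove, or {0, 0}
    // when no sufficiently long re-match of the new suffix is found.
    static std::pair<int, int> find_purge_range(const std::vector<int> &old_ctx,
                                                const std::vector<int> &new_ctx,
                                                int min_match)
    {
        // p0: length of the common prefix (the part of the context that did not change)
        int p0 = 0;
        while (p0 < (int)old_ctx.size() && p0 < (int)new_ctx.size() && old_ctx[p0] == new_ctx[p0])
        {
            ++p0;
        }
        if (p0 >= (int)new_ctx.size())
        {
            return {0, 0}; // nothing left to match
        }
        // Naive stand-in for the LCS check: look for the first min_match tokens of the
        // new suffix somewhere later in the old context; that position becomes p1.
        auto probe_begin = new_ctx.begin() + p0;
        auto probe_end   = probe_begin + std::min<int>(min_match, (int)(new_ctx.end() - probe_begin));
        auto hit = std::search(old_ctx.begin() + p0, old_ctx.end(), probe_begin, probe_end);
        if (hit == old_ctx.end() || probe_end - probe_begin < min_match)
        {
            return {0, 0}; // not enough overlap, do not purge
        }
        int p1 = (int)(hit - old_ctx.begin());
        return {p0, p1}; // tokens in [p0, p1) are the "middle portion" to drop from the KV
    }

    int main()
    {
        // Old context: prefix A, middle M (e.g. trimmed chat turns), suffix B.
        // New context: prefix A and suffix B only. Expected purge range: the M slice.
        std::vector<int> A = {1, 2, 3, 4}, M = {100, 101, 102, 103, 104}, B = {7, 8, 9, 10, 11};
        std::vector<int> old_ctx = A, new_ctx = A;
        old_ctx.insert(old_ctx.end(), M.begin(), M.end());
        old_ctx.insert(old_ctx.end(), B.begin(), B.end());
        new_ctx.insert(new_ctx.end(), B.begin(), B.end());

        auto range = find_purge_range(old_ctx, new_ctx, 4);
        printf("purge [%d, %d)\n", range.first, range.second); // expect [4, 9)
        return 0;
    }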
@@ -591,8 +595,8 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
 
-    const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
-    const int SlackAllowance = 64; //in case the end text is slightly modified, be forgiving
+    const int ShortfallThreshold = 200; //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 50; //in case the end text is slightly modified, be forgiving
 
     int trimstart = 0;
     int new_tokens_len = new_context_tokens.size();
@@ -621,7 +625,7 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     }
 
     //at least this many tokens need to match, otherwise don't bother trimming
-    const int LCSTokThreshold = std::max((new_tokens_len - trimstart) - (genamt+SlackAllowance), ShortfallThreshold-SlackAllowance);
+    const int LCSTokThreshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+SlackAllowance), (int)(nctx*0.55)), ShortfallThreshold-SlackAllowance);
 
     auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
     auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
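The effect of the new (int)(nctx*0.55) cap is easiest to see with numbers. A small worked example with assumed values (nctx = 4096, genamt = 512, 3800 remaining new tokens; none of these figures come from the commit), comparing the required match length before and after this change:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // Assumed inputs, purely illustrative
        const int nctx = 4096;
        const int genamt = 512;
        const int remaining = 3800; // new_tokens_len - trimstart

        // Before this commit: SlackAllowance = 64, ShortfallThreshold = 256, no cap,
        // so the required LCS length keeps growing with the prompt.
        const int old_thresh = std::max(remaining - (genamt + 64), 256 - 64);

        // After this commit: SlackAllowance = 50, ShortfallThreshold = 200,
        // and the requirement is capped at 55% of the context window.
        const int new_thresh = std::max(std::min(remaining - (genamt + 50), (int)(nctx * 0.55)), 200 - 50);

        printf("old LCSTokThreshold = %d\n", old_thresh); // 3224
        printf("new LCSTokThreshold = %d\n", new_thresh); // 2252
        return 0;
    }

With these numbers the cap lowers the bar from 3224 to 2252 matching tokens, which appears intended to let the purge still trigger on long prompts.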
@@ -956,7 +960,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
     //determine mem per token
     std::vector<int> tmp = {1, 2, 3, 4};
-    auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0);
+    llama_kv_cache_tokens_rm(llama_ctx_v4, -1, -1);
+    auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
     if(er!=0)
     {
         printf("\nLLAMA EVAL returned nonzero!\n");
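The memory-probe eval now goes through llama_decode with a llama_batch_get_one batch, clearing the KV cache first so the probe tokens are not left behind. A minimal sketch of that call pattern, assuming an already-initialised llama_context* (model loading and the surrounding koboldcpp state are omitted):

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Probe decode, mirroring the updated gpttype_load_model path.
    static bool warmup_decode(llama_context * llama_ctx_v4)
    {
        std::vector<int> tmp = {1, 2, 3, 4};

        // Clear the whole KV cache so the probe tokens do not linger afterwards.
        llama_kv_cache_tokens_rm(llama_ctx_v4, -1, -1);

        // Decode the probe batch at position 0, sequence 0.
        auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
        if (er != 0)
        {
            printf("\nLLAMA EVAL returned nonzero!\n");
            return false;
        }
        return true;
    }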
@@ -1459,12 +1464,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     else
     {
         bool triggersc = useSmartContext;
-        if(useSmartContext && file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
         {
-            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
-            triggersc = false;
+            if(useSmartContext)
+            {
+                PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
+                triggersc = false;
+            }
         }
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
+        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        {
+            llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+        }
     }
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
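One reading of the restructured branch: the token purge stays gated on useSmartContext, but the llama_kv_cache_seq_rm call after ContextFastForward now runs for GGUF models regardless, discarding cached entries at positions >= n_past (p1 = -1 meaning "to the end") before the tail is re-decoded. A condensed, annotated excerpt of that ordering; identifiers are the ones from the hunk above and the surrounding function is omitted, so this is commentary rather than a drop-in replacement:

    // Sketch of the new GGUF flow inside gpttype_generate
    bool triggersc = useSmartContext;
    if (file_format == FileFormat::GGUF_LLAMA || file_format == FileFormat::GGUF_FALCON)
    {
        if (useSmartContext)
        {
            // Drop the stale middle slice from both the token list and the KV cache
            // (destructive; no fast forward happens inside this call).
            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
            triggersc = false; // the older smartcontext mechanism is skipped when the purge ran
        }
    }

    // Reuse whatever prefix still matches; advances n_past accordingly.
    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);

    if (file_format == FileFormat::GGUF_LLAMA || file_format == FileFormat::GGUF_FALCON)
    {
        // Invalidate everything cached at positions >= n_past so the re-decoded tail
        // cannot collide with stale KV entries.
        llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
    }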
@@ -1603,9 +1615,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         std::string outstr = "";
         printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
-        outstr += print_tok_vec_str(embd_inp);
+        outstr += get_tok_vec_str(embd_inp);
         outstr += "\n\n[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
-        outstr += print_tok_vec_str(current_context_tokens);
+        outstr += get_tok_vec_str(current_context_tokens);
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }
 
@@ -1636,7 +1648,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
         {
-            evalres = (llama_eval(llama_ctx_v4, embd.data(), embdsize, n_past)==0);
+            evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
         }
         else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
         {
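The per-step eval uses the same decode path as the load-time probe, but anchored at n_past so each batch lands right after what is already in the KV cache. A minimal sketch of how the position argument advances across successive llama_decode calls; the context is assumed to be initialised elsewhere and error handling is trimmed:

    #include <vector>
    #include "llama.h"

    // Decodes a prompt in two chunks; the second llama_batch_get_one is anchored at
    // n_past, i.e. immediately after the tokens already in the KV cache.
    static bool decode_in_chunks(llama_context * ctx, std::vector<int> &embd)
    {
        int n_past = 0;
        const int first = (int)embd.size() / 2;

        if (llama_decode(ctx, llama_batch_get_one(embd.data(), first, n_past, 0)) != 0)
        {
            return false;
        }
        n_past += first; // cache now holds positions [0, first)

        const int rest = (int)embd.size() - first;
        bool evalres = (llama_decode(ctx, llama_batch_get_one(embd.data() + first, rest, n_past, 0)) == 0);
        n_past += rest;
        return evalres;
    }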
@@ -9836,7 +9836,7 @@ int llama_eval(
         llama_token * tokens,
             int32_t   n_tokens,
                 int   n_past) {
-    llama_kv_cache_seq_rm(ctx->kv_self, 0, n_past, -1);
+    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 
     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
     if (ret < 0) {