fixes to smartcontextpro
parent 20ef442c2a
commit 338d6c265d
2 changed files with 25 additions and 13 deletions

@@ -247,7 +247,7 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
     return word2;
 }
 
-static std::string print_tok_vec_str(std::vector<int> &embd)
+static std::string get_tok_vec_str(std::vector<int> &embd)
 {
     std::string tmp = "";
     for (auto id : embd)
@@ -257,6 +257,10 @@ static std::string print_tok_vec_str(std::vector<int> &embd)
     ::utreplace(tmp, "\n", "\\n");
     return tmp;
 }
+static void print_tok_vec_str(std::vector<int> &vec)
+{
+    printf("\n%s", get_tok_vec_str(vec).c_str());
+}
 
 
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
@@ -583,7 +587,7 @@ static void load_grammar(const std::string & gammarstr)
 
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
+void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
 {
     //scan from start old and new ctx, until first mismatch found, save as p0
     //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -591,8 +595,8 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
 
-    const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
-    const int SlackAllowance = 64; //in case the end text is slightly modified, be forgiving
+    const int ShortfallThreshold = 200; //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 50; //in case the end text is slightly modified, be forgiving
 
     int trimstart = 0;
     int new_tokens_len = new_context_tokens.size();
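For readers following these hunks: the comments describe a three-step trim (advance to the first mismatch p0, look for a sufficiently long common run whose start in the old context is p1, then drop everything between p0 and p1 from both the token array and the KV cache). Below is a minimal standalone sketch of how p0 and p1 could be located. It uses a longest common contiguous run as a stand-in for the "longest common subseq" the comment mentions, all names are illustrative, and it is not the code this commit actually touches.

// Illustrative sketch only; hypothetical names, not the repository's implementation.
#include <cstddef>
#include <vector>

struct TrimSpan
{
    size_t p0;        // end of the prefix shared by old and new context
    size_t p1;        // start (in the old context) of the best common run after p0
    size_t match_len; // length of that run; caller compares it against LCSTokThreshold
};

static TrimSpan find_trim_span(const std::vector<int> &old_ctx, const std::vector<int> &new_ctx)
{
    size_t p0 = 0;
    while (p0 < old_ctx.size() && p0 < new_ctx.size() && old_ctx[p0] == new_ctx[p0])
    {
        ++p0; // scan until the first mismatch
    }

    // Longest common contiguous run over the two remainders (simple O(n*m) DP).
    const size_t n = old_ctx.size() - p0, m = new_ctx.size() - p0;
    std::vector<std::vector<size_t>> dp(n + 1, std::vector<size_t>(m + 1, 0));
    TrimSpan best{p0, p0, 0};
    for (size_t i = 1; i <= n; ++i)
    {
        for (size_t j = 1; j <= m; ++j)
        {
            if (old_ctx[p0 + i - 1] == new_ctx[p0 + j - 1])
            {
                dp[i][j] = dp[i - 1][j - 1] + 1;
                if (dp[i][j] > best.match_len)
                {
                    best.match_len = dp[i][j];
                    best.p1 = p0 + i - dp[i][j];
                }
            }
        }
    }
    // If best.match_len clears the threshold, tokens in [best.p0, best.p1) would be
    // removed from both the token array and the KV cache.
    return best;
}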
@@ -621,7 +625,7 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     }
 
     //at least this many tokens need to match, otherwise don't bother trimming
-    const int LCSTokThreshold = std::max((new_tokens_len - trimstart) - (genamt+SlackAllowance), ShortfallThreshold-SlackAllowance);
+    const int LCSTokThreshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+SlackAllowance), (int)(nctx*0.55)), ShortfallThreshold-SlackAllowance);
 
     auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
     auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
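The functional change here is the added std::min against (int)(nctx*0.55): the match length required before trimming is now capped at roughly 55% of the context window instead of growing without bound with the prompt, which is presumably why PurgeMissingTokens gained the nctx parameter. A worked example with assumed numbers (none of these values come from the repository):

// Worked example of the new threshold; all inputs below are assumed for illustration.
#include <algorithm>
#include <cstdio>

int main()
{
    const int ShortfallThreshold = 200, SlackAllowance = 50;
    const int nctx = 4096;       // assumed context size
    const int genamt = 128;      // assumed generation amount
    const int remaining = 3800;  // assumed (new_tokens_len - trimstart)

    const int uncapped = remaining - (genamt + SlackAllowance);   // 3622
    const int capped   = std::min(uncapped, (int)(nctx * 0.55));  // 2252
    const int LCSTokThreshold = std::max(capped, ShortfallThreshold - SlackAllowance);

    printf("required match length: %d (would be %d without the nctx cap)\n",
           LCSTokThreshold, std::max(uncapped, ShortfallThreshold - SlackAllowance));
    return 0;
}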
@@ -956,7 +960,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
     //determine mem per token
     std::vector<int> tmp = {1, 2, 3, 4};
-    auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0);
+    llama_kv_cache_tokens_rm(llama_ctx_v4, -1, -1);
+    auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
     if(er!=0)
     {
         printf("\nLLAMA EVAL returned nonzero!\n");
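Both this hunk and the later generation hunk move off the llama_eval wrapper onto llama_decode with a batch built by llama_batch_get_one, and the warm-up pass now clears the KV cache first. A condensed sketch of that pattern, assuming a ready llama_context and the llama.cpp API of this commit's era (llama_kv_cache_tokens_rm was later dropped upstream):

// Sketch under the assumptions above; the model/context setup is omitted.
#include "llama.h"
#include <cstdio>
#include <vector>

static bool warmup_eval(llama_context * ctx)
{
    std::vector<llama_token> tmp = {1, 2, 3, 4};

    // Drop any existing cells so the throwaway decode starts from a clean KV cache.
    llama_kv_cache_tokens_rm(ctx, -1, -1);

    // llama_batch_get_one wraps a plain token array as a single-sequence batch
    // starting at position 0; llama_decode returns 0 on success.
    const int er = llama_decode(ctx, llama_batch_get_one(tmp.data(), (int)tmp.size(), 0, 0));
    if (er != 0)
    {
        printf("\nwarm-up decode returned nonzero (%d)\n", er);
        return false;
    }
    return true;
}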
@@ -1459,12 +1464,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     else
     {
         bool triggersc = useSmartContext;
-        if(useSmartContext && file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
         {
-            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
+            if(useSmartContext)
+            {
+                PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
                 triggersc = false;
+            }
         }
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
+        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        {
+            llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+        }
     }
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
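The other addition in this hunk runs after ContextFastForward for the GGUF formats: llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1) discards KV entries of sequence 0 at positions n_past and beyond, so stale cells from the previous generation cannot collide with the tokens about to be decoded at those positions. A tiny sketch of that intent, with the surrounding state assumed:

// Sketch assuming the llama.cpp semantics of this era: remove sequence 0's KV
// cells with position >= n_past (p1 == -1 means "to the end of the cache").
#include "llama.h"

static void truncate_kv_after_fast_forward(llama_context * ctx, int n_past)
{
    // Everything up to n_past is the reused prefix; anything beyond it is stale
    // state from the previous generation run.
    llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
}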
@@ -1603,9 +1615,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         std::string outstr = "";
         printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
-        outstr += print_tok_vec_str(embd_inp);
+        outstr += get_tok_vec_str(embd_inp);
         outstr += "\n\n[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
-        outstr += print_tok_vec_str(current_context_tokens);
+        outstr += get_tok_vec_str(current_context_tokens);
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }
 
@@ -1636,7 +1648,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
         {
-            evalres = (llama_eval(llama_ctx_v4, embd.data(), embdsize, n_past)==0);
+            evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
         }
         else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
         {

@@ -9836,7 +9836,7 @@ int llama_eval(
         llama_token * tokens,
         int32_t n_tokens,
         int n_past) {
-    llama_kv_cache_seq_rm(ctx->kv_self, 0, n_past, -1);
+    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 
     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
     if (ret < 0) {