revamped smart context for llama models
parent c2f675133d
commit 15f525c580
5 changed files with 90 additions and 34 deletions

Makefile (2 changed lines)
@@ -429,7 +429,7 @@ gpttype_adapter_cublas.o: $(GPTTYPE_ADAPTER)
 clean:
 	rm -vf *.o main quantize_llama quantize_gpt2 quantize_gptj quantize_neox quantize_mpt quantize-stats perplexity embedding benchmark-matmult save-load-state gguf gguf.exe main.exe quantize_llama.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_openblas.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_cublas.dll koboldcpp_hipblas.dll koboldcpp_default.so koboldcpp_openblas.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_cublas.so koboldcpp_hipblas.so
 
-main: examples/main/main.cpp build-info.h ggml.o $(KQ1) ggml-alloc.o ggml-backend.o llama.o common.o console.o grammar-parser.o $(OBJS)
+main: examples/main/main.cpp common/sampling.cpp build-info.h ggml.o $(KQ1) ggml-alloc.o ggml-backend.o llama.o common.o console.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='

@@ -247,6 +247,17 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
     return word2;
 }
 
+static std::string print_tok_vec_str(std::vector<int> &embd)
+{
+    std::string tmp = "";
+    for (auto id : embd)
+    {
+        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
+    }
+    ::utreplace(tmp, "\n", "\\n");
+    return tmp;
+}
+
 
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
 {
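The new print_tok_vec_str helper simply formats a token-id vector into a readable string for the debug dumps later in this file. Below is a minimal, self-contained sketch of the same idea: lookup_token_text() is a hypothetical stand-in for FileFormatTokenizeID and the adapter's global file_format, and the newline escaping done by ::utreplace is omitted.

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for FileFormatTokenizeID(id, file_format), which needs
// the loaded model's tokenizer; here we just fabricate a label per id.
static std::string lookup_token_text(int id)
{
    return "tok" + std::to_string(id);
}

static std::string print_tok_vec_str(std::vector<int> &embd)
{
    std::string tmp = "";
    for (auto id : embd)
    {
        tmp += "'" + lookup_token_text(id) + " (" + std::to_string(id) + ")', ";
    }
    return tmp; // the real helper additionally escapes newlines via ::utreplace
}

int main()
{
    std::vector<int> ids = {1, 15043, 3186};
    printf("[Debug: Dump Input Tokens]\n%s\n", print_tok_vec_str(ids).c_str());
    return 0;
}

Factoring this into one helper is what lets the token-dump section further down in gpttype_generate collapse to two calls.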
@@ -572,7 +583,7 @@ static void load_grammar(const std::string & gammarstr)
 
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
+void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
 {
     //scan from start old and new ctx, until first mismatch found, save as p0
     //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -580,20 +591,63 @@ void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<in
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
 
-    // int trimstart = 0;
+    const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 32; //in case the end text is slightly modified, be forgiving
 
-    // const int n_keep = 0;
-    // const int n_left = n_past - n_keep - 1;
-    // const int n_discard = n_left/2;
+    int trimstart = 0;
+    int new_tokens_len = new_context_tokens.size();
+    bool purgeneeded = true;
 
-    // printf("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-    //     n_past, n_left, nctx, n_keep, n_discard);
+    for (int i = 0; i < current_context_tokens.size(); ++i)
+    {
+        if (current_context_tokens[i] == new_context_tokens[i])
+        {
+            trimstart += 1;
+        }
+        else
+        {
+            break;
+        }
+        if ((i + 2) >= new_tokens_len)
+        {
+            purgeneeded = false;
+            break; //no surgery required
+        }
+    }
 
-    // llama_kv_cache_seq_rm (llama_ctx_v4, 0, n_keep + 1 , n_keep + n_discard + 1);
-    // llama_kv_cache_seq_shift(llama_ctx_v4, 0, n_keep + 1 + n_discard, n_past, -n_discard);
-    // n_past -= n_discard;
+    if(!purgeneeded || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < ShortfallThreshold)
+    {
+        return; //no purge is needed
+    }
 
-    // printf("after swap: n_past = %d\n", n_past);
+    //at least this many tokens need to match, otherwise don't bother trimming
+    const int LCQTokThreshold = std::max((new_tokens_len - trimstart) - (genamt+SlackAllowance), ShortfallThreshold-SlackAllowance);
+
+    auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
+    auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
+
+    auto shared = LongestCommonSubseq(curr_ctx_without_memory, new_ctx_without_memory);
+
+    if (shared.size() > LCQTokThreshold && ArrStartWith(new_ctx_without_memory, shared)) // enough tokens in common
+    {
+        int found = ArrFindIndexOf(current_context_tokens,shared);
+        if(found>=0 && found > trimstart)
+        {
+            //extract the unwanted tokens out from context and KV
+            int diff = found - trimstart;
+            llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart + 1, trimstart + diff + 1);
+            llama_kv_cache_seq_shift(llama_ctx_v4, 0, trimstart + diff + 1, -1, -diff);
+
+            for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+            {
+                current_context_tokens[i - diff] = current_context_tokens[i];
+            }
+
+            printf("\n[Smart Context Pro: Erased %d tokens at position %d]", diff, trimstart+1);
+
+            current_context_tokens.resize(current_context_tokens.size() - diff - 1);
+        }
+    }
+
 }

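The function above is the heart of the revamped smart context: keep the prefix both contexts share, find the block that the client trimmed out of the middle of its prompt, evict exactly that block from the KV cache, and slide the remaining cells back. The bookkeeping can be exercised in isolation; the sketch below runs the same steps on plain token-id vectors, with no llama.cpp calls, a longest-common-contiguous-run helper standing in for the adapter's LongestCommonSubseq/ArrStartWith/ArrFindIndexOf utilities, and the 256-token thresholds dropped so toy data triggers the path.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Longest common contiguous run between a and b (stand-in for the adapter's
// helper). Returns the matching block, taken from b.
static std::vector<int> longest_common_run(const std::vector<int> &a, const std::vector<int> &b)
{
    size_t best_len = 0, best_end_b = 0;
    std::vector<size_t> prev(b.size() + 1, 0), cur(b.size() + 1, 0);
    for (size_t i = 1; i <= a.size(); ++i) {
        for (size_t j = 1; j <= b.size(); ++j) {
            cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1 : 0;
            if (cur[j] > best_len) { best_len = cur[j]; best_end_b = j; }
        }
        std::swap(prev, cur);
    }
    return std::vector<int>(b.begin() + (best_end_b - best_len), b.begin() + best_end_b);
}

int main()
{
    // old_ctx is what was processed last turn; new_ctx is the incoming prompt
    // with tokens 4..6 trimmed out of the middle (e.g. oldest chat turns dropped).
    std::vector<int> old_ctx = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    std::vector<int> new_ctx = {1, 2, 3, 7, 8, 9, 10, 11};

    // p0: length of the shared prefix (trimstart in the real code).
    size_t trimstart = 0;
    while (trimstart < old_ctx.size() && trimstart < new_ctx.size() &&
           old_ctx[trimstart] == new_ctx[trimstart]) { ++trimstart; }

    // Find the block shared after the prefix; its start in old_ctx plays the role of p1.
    std::vector<int> old_rest(old_ctx.begin() + trimstart, old_ctx.end());
    std::vector<int> new_rest(new_ctx.begin() + trimstart, new_ctx.end());
    std::vector<int> shared = longest_common_run(old_rest, new_rest);

    // Require the new context to continue directly with the shared block
    // (ArrStartWith), then locate that block inside the old context (ArrFindIndexOf).
    bool ok = !shared.empty() &&
              std::equal(shared.begin(), shared.end(), new_rest.begin());
    auto it = ok ? std::search(old_ctx.begin(), old_ctx.end(), shared.begin(), shared.end())
                 : old_ctx.end();
    if (it != old_ctx.end()) {
        size_t found = it - old_ctx.begin();
        size_t diff  = found - trimstart;   // number of stale tokens in the middle
        // The real code evicts the corresponding cells with llama_kv_cache_seq_rm and
        // shifts the tail back with llama_kv_cache_seq_shift; here we just drop them.
        old_ctx.erase(old_ctx.begin() + trimstart, old_ctx.begin() + trimstart + diff);
        printf("[sketch: erased %zu tokens at position %zu]\n", diff, trimstart);
    }

    for (int t : old_ctx) printf("%d ", t);  // prints: 1 2 3 7 8 9 10
    printf("\n");
    return 0;
}

Because the trimmed block is removed from the KV cache rather than worked around, the old context can be reused even though its middle no longer matches, which is what lets the adapter avoid a full reprocess of the prompt.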
@@ -1398,15 +1452,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    PurgeMissingTokens(current_context_tokens, embd_inp);
-
     if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
     }
     else
     {
-        ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext, false);
+        bool triggersc = useSmartContext;
+        if(useSmartContext && file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        {
+            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
+            triggersc = false;
+        }
+        ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
    }
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch

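Condensed, the new flow in gpttype_generate is: RWKV keeps its plain fast-forward; GGUF llama/falcon prompts are first reconciled against the previous context via PurgeMissingTokens and then fast-forwarded with the legacy smart-context reservation disabled; every other format keeps the old behaviour. A stubbed sketch of that selection logic, with names and state heavily simplified and under the assumption that the purge is intended to apply only when smart context is enabled for a GGUF llama or falcon model:

#include <cstdio>

enum class Fmt { RWKV, GGUF_LLAMA, GGUF_FALCON, OTHER };

// Stubs standing in for the real adapter calls.
static void purge_missing_tokens() { printf("purge stale middle from KV cache\n"); }
static void context_fast_forward(bool use_smartcontext)
{
    printf("fast forward, smartcontext=%s\n", use_smartcontext ? "on" : "off");
}

static void prepare_context(Fmt fmt, bool useSmartContext)
{
    if (fmt == Fmt::RWKV) {
        context_fast_forward(false);      // RWKV path never uses smart context
        return;
    }
    bool triggersc = useSmartContext;
    if (useSmartContext && (fmt == Fmt::GGUF_LLAMA || fmt == Fmt::GGUF_FALCON)) {
        purge_missing_tokens();           // new path: trim the KV cache in place
        triggersc = false;                // old context reservation no longer needed
    }
    context_fast_forward(triggersc);
}

int main()
{
    prepare_context(Fmt::GGUF_LLAMA, true); // purge, then fast forward
    prepare_context(Fmt::OTHER, true);      // legacy smart context path
    return 0;
}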
@@ -1545,23 +1603,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         std::string outstr = "";
         printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
-
-        std::string tmp = "";
-        for (auto id : embd_inp)
-        {
-            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
-        }
-        ::utreplace(tmp, "\n", "\\n");
-        outstr += tmp;
-
+        outstr += print_tok_vec_str(embd_inp);
         outstr += "\n\n[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
-        tmp = "";
-        for (auto id : current_context_tokens)
-        {
-            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
-        }
-        ::utreplace(tmp, "\n", "\\n");
-        outstr += tmp;
+        outstr += print_tok_vec_str(current_context_tokens);
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }
 

@@ -926,7 +926,7 @@ def show_new_gui():
     # slider data
     blasbatchsize_values = ["-1", "32", "64", "128", "256", "512", "1024", "2048"]
     blasbatchsize_text = ["Don't Batch BLAS","32","64","128","256","512","1024","2048"]
-    contextsize_text = ["512", "1024", "2048", "3072", "4096", "6144", "8192", "12288", "16384", "24576", "32768", "65536"]
+    contextsize_text = ["256", "512", "1024", "2048", "3072", "4096", "6144", "8192", "12288", "16384", "24576", "32768", "65536"]
     runopts = [opt for lib, opt in lib_option_pairs if file_exists(lib)]
     antirunopts = [opt.replace("Use ", "") for lib, opt in lib_option_pairs if not (opt in runopts)]
     if not any(runopts):

@@ -1914,7 +1914,7 @@ if __name__ == '__main__':
     parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
     parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
     parser.add_argument("--highpriority", help="Experimental flag. If set, increases the process CPU priority, potentially speeding up generation. Use caution.", action='store_true')
-    parser.add_argument("--contextsize", help="Controls the memory allocated for maximum context size, only change if you need more RAM for big contexts. (default 2048)", type=int,choices=[512,1024,2048,3072,4096,6144,8192,12288,16384,24576,32768,65536], default=2048)
+    parser.add_argument("--contextsize", help="Controls the memory allocated for maximum context size, only change if you need more RAM for big contexts. (default 2048)", type=int,choices=[256, 512,1024,2048,3072,4096,6144,8192,12288,16384,24576,32768,65536], default=2048)
     parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512). Setting it to -1 disables BLAS mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,32,64,128,256,512,1024,2048], default=512)
     parser.add_argument("--ropeconfig", help="If set, uses customized RoPE scaling from configured frequency scale and frequency base (e.g. --ropeconfig 0.25 10000). Otherwise, uses NTK-Aware scaling set automatically based on context size. For linear rope, simply set the freq-scale and ignore the freq-base",metavar=('[rope-freq-scale]', '[rope-freq-base]'), default=[0.0, 10000.0], type=float, nargs='+')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')

llama.cpp (12 changed lines)
@@ -9819,12 +9819,22 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     return true;
 }
 
+void printcache(struct llama_context * ctx)
+{
+    struct llama_kv_cache & cache = ctx->kv_self;
+    std::string vals = "\n";
+    for (int32_t i = 0; i < cache.size; ++i) {
+        vals += std::to_string(i) + "= pos:" + std::to_string(cache.cells[i].pos) + " delta:" + std::to_string(cache.cells[i].delta) +"\n";
+    }
+    printf("%s",vals.c_str());
+}
+
 int llama_eval(
         struct llama_context * ctx,
                  llama_token * tokens,
                      int32_t   n_tokens,
                          int   n_past) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+    llama_kv_cache_seq_rm(ctx->kv_self, 0, n_past, -1);
 
     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
     if (ret < 0) {

llama.h (2 changed lines)
@@ -350,6 +350,8 @@ extern "C" {
             llama_pos   p0,
             llama_pos   p1);
 
+    LLAMA_API void printcache(struct llama_context * ctx);
+
     // Copy all tokens that belong to the specified sequence to another sequence
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0, p1]
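printcache, now exported here, is a debugging aid: it walks the context's KV cache and prints each cell's position and accumulated shift delta, which makes it easy to verify what llama_kv_cache_seq_rm and llama_kv_cache_seq_shift (as used by PurgeMissingTokens above) actually did. A hypothetical call site, assuming ctx is an already-initialized llama_context with some tokens decoded (model loading and decoding omitted):

#include "llama.h"

// Hypothetical helper: dump the KV cells, drop the stale window between p0 and p1
// for sequence 0, slide the tail back by the same amount, then dump again.
static void debug_purge_window(struct llama_context * ctx, llama_pos p0, llama_pos p1)
{
    printcache(ctx);                                    // cell positions before the edit
    llama_kv_cache_seq_rm(ctx, 0, p0, p1);              // evict the unwanted cells
    llama_kv_cache_seq_shift(ctx, 0, p1, -1, p0 - p1);  // shift everything after p1 back
    printcache(ctx);                                    // deltas now reflect the shift
}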