context shift feature done

Concedo 2023-10-29 18:21:39 +08:00
parent 338d6c265d
commit 7924592a83
4 changed files with 41 additions and 18 deletions

View file

@@ -38,6 +38,7 @@ struct load_model_inputs
     const bool use_mmap;
     const bool use_mlock;
     const bool use_smartcontext;
+    const bool use_contextshift;
     const int clblast_info = 0;
     const int cublas_info = 0;
     const int blasbatchsize = 512;

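Note: the new use_contextshift bool slots in between use_smartcontext and clblast_info. Because this struct is filled from Python through ctypes (the mirror appears in the launcher hunks later in this commit), the Python _fields_ list has to add the flag at the same position, or every later field would be read at the wrong offset. A minimal sketch, trimmed to just the fields visible in this hunk:

    # Sketch only: a trimmed mirror of the C struct above. Field order must match
    # the C header exactly, since ctypes lays the structure out in _fields_ order.
    import ctypes

    class load_model_inputs(ctypes.Structure):
        _fields_ = [
            ("use_mmap", ctypes.c_bool),
            ("use_mlock", ctypes.c_bool),
            ("use_smartcontext", ctypes.c_bool),
            ("use_contextshift", ctypes.c_bool),  # new field, same position as in the C header
            ("clblast_info", ctypes.c_int),
            ("cublas_info", ctypes.c_int),
            ("blasbatchsize", ctypes.c_int),
        ]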
View file

@@ -78,6 +78,7 @@ static int n_threads = 4;
 static int n_blasthreads = 4;
 static int n_batch = 8;
 static bool useSmartContext = false;
+static bool useContextShift = false;
 static int blasbatchsize = 512;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
@@ -647,7 +648,7 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
         current_context_tokens[i - diff] = current_context_tokens[i];
     }
-    printf("\n[Smart Context Pro: Erased %d tokens at position %d]", diff, trimstart+1);
+    printf("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart+1);
     current_context_tokens.resize(current_context_tokens.size() - diff - 1);
 }
@@ -665,6 +666,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
+    useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;
     blasbatchsize = inputs.blasbatchsize;
     if(blasbatchsize<=0)
@@ -1464,13 +1466,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     else
     {
         bool triggersc = useSmartContext;
-        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        if(useContextShift && (file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON))
         {
-            if(useSmartContext)
-            {
-                PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
-                triggersc = false;
-            }
+            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
+            triggersc = false;
         }
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
         if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
@@ -1717,7 +1716,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if (!evalres)
     {
-        fprintf(stderr, "Failed to predict\n");
+        fprintf(stderr, "\nFailed to predict! Check your context buffer sizes!\n");
         snprintf(output.text, sizeof(output.text), "%s", "");
         output.status = 0;
         generation_finished = true;

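Note: in the generation path the token purge is now gated on the new useContextShift flag rather than on SmartContext, and only for the GGUF Llama/Falcon formats; when it fires, the SmartContext trigger is suppressed before ContextFastForward runs. A rough Python restatement of that branch (purge_missing_tokens and context_fast_forward are hypothetical stand-ins passed in as callables, not real APIs):

    # Sketch of the dispatch order in the C++ hunk above.
    def prepare_context(file_format, use_context_shift, use_smart_context,
                        purge_missing_tokens, context_fast_forward):
        trigger_smart_context = use_smart_context
        if use_context_shift and file_format in ("GGUF_LLAMA", "GGUF_FALCON"):
            # Context Shifting: erase the tokens that fell out of the window and keep
            # the KV cache, instead of reprocessing the whole prompt.
            purge_missing_tokens()
            trigger_smart_context = False  # SmartContext is skipped on this pass
        # Fast-forward over whatever prefix still matches the cached context.
        context_fast_forward(trigger_smart_context)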
View file

@@ -6,7 +6,7 @@ It requires no dependencies, installation or setup.
 Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
 Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
 Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 87
+Current version: 88
 -Concedo
 -->
@@ -3050,6 +3050,16 @@ Current version: 87
     return segmentsA.length - segmentsB.length;
 }
+function countWords(str) { //simple word counter
+    if (str == "") { return 0; }
+    const wordPattern = /[a-zA-Z0-9_]+/g;
+    const words = str.match(wordPattern);
+    if (!words) {
+        return 0;
+    }
+    return words.length;
+}
 function convertMarkdownTableToHtml(t){let hsep = /^[\s]*\|(?:[\s]*[-:]+[-:|\s]*)+\|[\s]*$/gm;let l=/^[\s]*\|(.*)\|[\s]*$/gm,r=t.split(/\r?\n|\r/),e="<table class='tablelines'>";for(let o of r){let hs=o.match(hsep);if(hs){continue;}let d=o.match(l);if(d){let i=d[0].split("|").map(t=>t.trim());e+=`<tr class='tablelines'><td class='tablelines'>${i.join("</td><td class='tablelines'>")}</td></tr>`}}return e+"</table>"}
 //casualwriter casual-markdown, under MIT license
@@ -7875,16 +7885,23 @@ Current version: 87
         }
     }
-    //this is a hack since we dont have a proper tokenizer, but we can estimate 1 token per 3.3 characters
-    let max_allowed_characters = Math.max(1, Math.floor(maxctxlen * 3) - (maxgenamt+8));
+    let truncated_context = concat_gametext(true, ""); //no need to truncate if memory is empty
+    truncated_context = truncated_context.replace(/\xA0/g,' '); //replace non breaking space nbsp
+    //this is a hack since we dont have a proper tokenizer, but we can estimate 1 token per 3 characters
+    let chars_per_token = 3.0;
+    //we try to detect attempts at coding which tokenize poorly. This usually happens when the average word length is high.
+    let avgwordlen = (1.0+truncated_context.length)/(1.0+countWords(truncated_context));
+    if(avgwordlen>=7.8)
+    {
+        chars_per_token = 2.7;
+    }
     if (current_memory == null || current_memory.trim() == "")
     {
         //if there is no memory, then we can be a lot of lenient with the character counts since the backend will truncate excess anyway
-        max_allowed_characters = Math.floor(maxctxlen * 4.6);
+        chars_per_token = 4.8;
     }
-    let truncated_context = concat_gametext(true, ""); //no need to truncate if memory is empty
-    truncated_context = truncated_context.replace(/\xA0/g,' '); //replace non breaking space nbsp
+    let max_allowed_characters = Math.max(1, Math.floor((maxctxlen-maxgenamt) * chars_per_token) - 8);
     //for adventure mode, inject hidden context, even more if there's nothing in memory
     if (localsettings.opmode == 2 && localsettings.adventure_context_mod)

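Note: the Lite frontend still estimates the token budget from character counts, but the divisor now adapts: roughly 3 characters per token by default, 2.7 for code-like text (detected via a high average word length), and 4.8 when memory is empty since the backend truncates any excess anyway; the generation budget is also subtracted from the context size before converting to characters. A small Python sketch of the same heuristic (estimate_char_budget is a hypothetical name):

    import re

    def estimate_char_budget(text, max_ctx_len, max_gen_amt, memory=""):
        """Rough restatement of the Lite heuristic above: ~1 token per 3 characters,
        tightened for code-like text and loosened when there is no memory."""
        words = re.findall(r"[a-zA-Z0-9_]+", text)          # same pattern as countWords()
        avg_word_len = (1.0 + len(text)) / (1.0 + len(words))
        chars_per_token = 3.0
        if avg_word_len >= 7.8:
            chars_per_token = 2.7                            # code tokenizes poorly
        if memory is None or memory.strip() == "":
            chars_per_token = 4.8                            # backend will truncate excess
        return max(1, int((max_ctx_len - max_gen_amt) * chars_per_token) - 8)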
View file

@@ -34,6 +34,7 @@ class load_model_inputs(ctypes.Structure):
         ("use_mmap", ctypes.c_bool),
         ("use_mlock", ctypes.c_bool),
         ("use_smartcontext", ctypes.c_bool),
+        ("use_contextshift", ctypes.c_bool),
         ("clblast_info", ctypes.c_int),
         ("cublas_info", ctypes.c_int),
         ("blasbatchsize", ctypes.c_int),
@@ -227,6 +228,7 @@ def load_model(model_filename):
     if len(args.lora) > 1:
         inputs.lora_base = args.lora[1].encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
+    inputs.use_contextshift = (not args.nocontextshift)
     inputs.blasbatchsize = args.blasbatchsize
     inputs.forceversion = args.forceversion
     inputs.gpulayers = args.gpulayers
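Note: the launcher exposes the feature as a negative CLI switch, so the ctypes field is the inverse of the flag and context shifting stays enabled unless --nocontextshift is passed. A stand-alone sketch of that inversion (the real parser is the one extended at the end of this commit):

    import argparse

    # Hypothetical minimal parser mirroring the flag added below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--nocontextshift", action="store_true")

    args = parser.parse_args([])                     # no flag given
    print(not args.nocontextshift)                   # True  -> use_contextshift = True
    args = parser.parse_args(["--nocontextshift"])   # flag given
    print(not args.nocontextshift)                   # False -> use_contextshift = False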
@@ -1045,6 +1047,7 @@ def show_new_gui():
     version_var = ctk.StringVar(value="0")
     tensor_split_str_vars = ctk.StringVar(value="")
+    contextshift = ctk.IntVar(value=1)
     smartcontext = ctk.IntVar()
     context_var = ctk.IntVar()
     customrope_var = ctk.IntVar()
@@ -1142,7 +1145,7 @@ def show_new_gui():
     makeslider(quick_tab, "BLAS Batch Size:", blasbatchsize_text, blas_size_var, 0, 7, 12, set=5)
     # quick boxes
-    quick_boxes = {"Launch Browser": launchbrowser , "High Priority" : highpriority, "Use SmartContext":smartcontext, "Disable MMAP":disablemmap,}
+    quick_boxes = {"Launch Browser": launchbrowser , "High Priority" : highpriority, "Use SmartContext":smartcontext, "Disable MMAP":disablemmap,"Use ContextShift":contextshift}
     for idx, name, in enumerate(quick_boxes):
         makecheckbox(quick_tab, name, quick_boxes[name], int(idx/2) +20, idx%2)
     # context size
@@ -1194,7 +1197,7 @@ def show_new_gui():
     # Tokens Tab
     tokens_tab = tabcontent["Tokens"]
     # tokens checkboxes
-    token_boxes = {"Use SmartContext":smartcontext}
+    token_boxes = {"Use SmartContext":smartcontext, "Use ContextShift":contextshift}
     for idx, name, in enumerate(token_boxes):
         makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)
@@ -1273,6 +1276,7 @@ def show_new_gui():
     args.highpriority = highpriority.get()==1
     args.nommap = disablemmap.get()==1
     args.smartcontext = smartcontext.get()==1
+    args.nocontextshift = contextshift.get()==0
     args.foreground = keepforeground.get()==1
     gpuchoiceidx = 0
@@ -1336,6 +1340,7 @@ def show_new_gui():
     highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
     disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
     smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
+    contextshift.set(0 if "nocontextshift" in dict and dict["nocontextshift"] else 1)
     keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
     if "useclblast" in dict and dict["useclblast"]:
         if clblast_option is not None:
@@ -1833,7 +1838,7 @@ def main(launch_args,start_server=True):
     modelname = os.path.abspath(args.model_param)
     print(args)
-    print(f"==========\nLoading model: {modelname} \n[Threads: {args.threads}, BlasThreads: {args.blasthreads}, SmartContext: {args.smartcontext}]")
+    print(f"==========\nLoading model: {modelname} \n[Threads: {args.threads}, BlasThreads: {args.blasthreads}, SmartContext: {args.smartcontext}, ContextShift: {not (args.nocontextshift)}]")
     loadok = load_model(modelname)
     print("Load Model OK: " + str(loadok))
@@ -1917,6 +1922,7 @@ if __name__ == '__main__':
     parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512). Setting it to -1 disables BLAS mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,32,64,128,256,512,1024,2048], default=512)
     parser.add_argument("--ropeconfig", help="If set, uses customized RoPE scaling from configured frequency scale and frequency base (e.g. --ropeconfig 0.25 10000). Otherwise, uses NTK-Aware scaling set automatically based on context size. For linear rope, simply set the freq-scale and ignore the freq-base",metavar=('[rope-freq-scale]', '[rope-freq-base]'), default=[0.0, 10000.0], type=float, nargs='+')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
+    parser.add_argument("--nocontextshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
     parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')