diff --git a/expose.cpp b/expose.cpp
index d385ffcb7..aeb7066f1 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -217,6 +217,9 @@ extern "C"
     int get_last_token_count() {
        return last_token_count;
     }
+    int get_total_gens() {
+       return total_gens;
+    }
     int get_last_stop_reason() {
        return (int)last_stop_reason;
     }
diff --git a/expose.h b/expose.h
index e3c069a66..9f686f75c 100644
--- a/expose.h
+++ b/expose.h
@@ -91,4 +91,5 @@ extern bool generation_finished;
 extern float last_eval_time;
 extern float last_process_time;
 extern int last_token_count;
+extern int total_gens;
 extern stop_reason last_stop_reason;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index c3d41e3ed..510536f64 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -39,6 +39,7 @@
 bool generation_finished;
 float last_process_time = 0;
 float last_eval_time = 0;
 int last_token_count = 0;
+int total_gens = 0;
 stop_reason last_stop_reason = stop_reason::INVALID;
 std::vector<std::string> generated_tokens;
@@ -597,8 +598,8 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
 
-    const int ShortfallThreshold = 200 + (nctx/40); //dont trigger shifting if the distance between trimstart and currhead < this
-    const int SlackAllowance = 50 + (nctx/80); //in case the end text is slightly modified, be forgiving
+    const int ShortfallThreshold = 200 + (nctx/20); //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 50 + (nctx/60); //in case the end text is slightly modified, be forgiving
 
     int trimstart = 0;
     int new_tokens_len = new_context_tokens.size();
@@ -1955,6 +1956,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     last_eval_time = pt2;
     last_process_time = pt1;
     last_token_count = realnpredict;
+    total_gens += 1;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 
     return output;
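Note on the PurgeMissingTokens tuning above: raising ShortfallThreshold means a larger gap between trimstart and the current head is required before context shifting triggers, and the larger SlackAllowance is more forgiving of slightly modified trailing text. A before/after sketch of the values, mirroring the two C++ formulas in Python (illustration only, not part of the patch):

# Formulas copied from the hunk above; C++ integer division
# matches Python's // for these positive values.
for nctx in (2048, 4096, 8192):
    old_shortfall, new_shortfall = 200 + nctx // 40, 200 + nctx // 20
    old_slack, new_slack = 50 + nctx // 80, 50 + nctx // 60
    print(f"nctx={nctx}: ShortfallThreshold {old_shortfall} -> {new_shortfall}, "
          f"SlackAllowance {old_slack} -> {new_slack}")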
diff --git a/klite.embd b/klite.embd
index f61439a74..7b4587fae 100644
--- a/klite.embd
+++ b/klite.embd
@@ -6,7 +6,7 @@
 It requires no dependencies, installation or setup.
 Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
 Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
 Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 94
+Current version: 95
 -Concedo
 -->
@@ -3168,6 +3168,7 @@ Current version: 94
        const koboldcpp_abort_endpoint = "/api/extra/abort";
        const koboldcpp_check_endpoint = "/api/extra/generate/check";
        const koboldcpp_truemaxctxlen_endpoint = "/api/extra/true_max_context_length";
+       const koboldcpp_preloadstory_endpoint = "/api/extra/preloadstory";
 
        const oai_models_endpoint = "/models";
        const oai_submit_endpoint = "/completions";
@@ -6092,6 +6093,28 @@ Current version: 94
            console.log("Failed to get true max ctx: " + error);
        });
 
+       //and check if there's a kcpp savefile preloaded
+       let urls5 = [
+           apply_proxy_url(tmpep + koboldcpp_preloadstory_endpoint),
+       ];
+       Promise.all(urls5.map(url => fetch(url)
+       .then(response => response.json())))
+       .then(values5 => {
+           let tmpstory = values5[0];
+           let is_kai = !(tmpstory.prompt==null);
+           if(is_kai)
+           {
+               let safe_to_overwrite = (gametext_arr.length == 0 && current_memory == "" && current_anote == "" && current_wi.length == 0 && redo_arr.length == 0);
+               if (localsettings.persist_session && !safe_to_overwrite) {
+                   console.log("Preload story: Unsafe to overwrite");
+               } else {
+                   kai_json_load(tmpstory, false);
+               }
+           }
+       }).catch(error => {
+           console.log("Failed to get preloaded story: " + error);
+       });
+
        }else{
            console.log("Unknown KoboldCpp Check Response: " + data);
        }
@@ -7233,7 +7256,8 @@ Current version: 94
        toggle_invert_colors();
        hide_popups();
-       render_gametext(); //need to always autosave, so that we can switch back to non persistent sessions
+       autosave();//need to always autosave, so that we can switch back to non persistent sessions
+       render_gametext(false);
    }
 
    function toggle_instruct_tag_format()
diff --git a/koboldcpp.py b/koboldcpp.py
index 0a1facd92..5d4e9f404 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -214,6 +214,7 @@ def init_library():
     handle.get_last_eval_time.restype = ctypes.c_float
     handle.get_last_process_time.restype = ctypes.c_float
     handle.get_last_token_count.restype = ctypes.c_int
+    handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
     handle.token_count.restype = ctypes.c_int
@@ -401,6 +402,7 @@ totalgens = 0
 currentusergenkey = "" #store a special key so polled streaming works even in multiuser
 args = None #global args
 gui_layers_untouched = True
+preloaded_story = None
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
@@ -618,7 +620,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -658,8 +660,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             lastp = handle.get_last_process_time()
             laste = handle.get_last_eval_time()
             lastc = handle.get_last_token_count()
+            totalgens = handle.get_total_gens()
             stopreason = handle.get_last_stop_reason()
-            response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
+            response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "total_gens":totalgens, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
 
         elif self.path.endswith('/api/extra/generate/check'):
             pendtxtStr = ""
@@ -677,6 +680,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode()
             else:
                 response_body = self.embedded_kcpp_docs
+
+        elif self.path=="/api/extra/preloadstory":
+            if preloaded_story is None:
+                response_body = (json.dumps({}).encode())
+            else:
+                response_body = preloaded_story
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
             self.path = "/api"
             self.send_response(302)
@@ -1008,7 +1017,8 @@ def show_new_gui():
     model_var = ctk.StringVar()
     lora_var = ctk.StringVar()
-    lora_base_var = ctk.StringVar()
+    lora_base_var = ctk.StringVar()
+    preloadstory_var = ctk.StringVar()
 
     port_var = ctk.StringVar(value=defaultport)
     host_var = ctk.StringVar(value="")
@@ -1404,6 +1414,7 @@
     makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=autoset_gpu_layers)
     makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3)
     makefileentry(model_tab, "Lora Base:", "Select Lora Base File", lora_base_var, 5)
+    makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 7)
 
     # Network Tab
     network_tab = tabcontent["Network"]
@@ -1505,6 +1516,7 @@
 
         args.model_param = None if model_var.get() == "" else model_var.get()
         args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
+        args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
 
         args.port_param = defaultport if port_var.get()=="" else int(port_var.get())
         args.host = host_var.get()
@@ -1595,6 +1607,9 @@
             else:
                 lora_var.set(dict["lora"][0])
 
+        if "preloadstory" in dict and dict["preloadstory"]:
+            preloadstory_var.set(dict["preloadstory"])
+
         if "port_param" in dict and dict["port_param"]:
             port_var.set(dict["port_param"])
 
@@ -1963,6 +1978,7 @@ def unload_libs():
     del handle.get_last_eval_time
     del handle.get_last_process_time
     del handle.get_last_token_count
+    del handle.get_total_gens
     del handle.get_last_stop_reason
     del handle.abort_generate
     del handle.token_count
@@ -2018,6 +2034,17 @@ def main(launch_args,start_server=True):
             time.sleep(3)
             sys.exit(2)
 
+    #try to read story if provided
+    if args.preloadstory:
+        if isinstance(args.preloadstory, str) and os.path.exists(args.preloadstory):
+            print(f"Preloading saved story {args.preloadstory} into server...")
+            with open(args.preloadstory, mode='rb') as f:
+                global preloaded_story
+                preloaded_story = f.read()
+            print("Saved story preloaded.")
+        else:
+            print(f"Warning: Saved story file {args.preloadstory} invalid or not found. No story will be preloaded into server.")
+
     # sanitize and replace the default vanity name. remember me....
     if args.model_param!="":
         newmdldisplayname = os.path.basename(args.model_param)
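For reference, a minimal client-side sketch of how the two server additions above can be consumed. This is illustrative and not part of the patch: it assumes a local instance on the default port (5001), and it assumes the perf route is /api/extra/perf, since only the handler body appears in the hunk; /api/extra/preloadstory is taken verbatim from the diff.

import json
import urllib.request

BASE = "http://localhost:5001"  # assumed default KoboldCpp port

# Poll the perf endpoint; this patch adds the "total_gens" counter to its JSON.
with urllib.request.urlopen(BASE + "/api/extra/perf") as resp:
    perf = json.load(resp)
print("total generations served:", perf["total_gens"])

# Fetch the preloaded story; the server answers {} when none is hosted.
with urllib.request.urlopen(BASE + "/api/extra/preloadstory") as resp:
    story = json.load(resp)

# Kobold Lite treats a response with a non-null "prompt" field as a loadable save.
if story.get("prompt") is not None:
    print("server is hosting a preloaded story")
else:
    print("no preloaded story hosted")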
@@ -2201,6 +2228,7 @@ if __name__ == '__main__':
     parser.add_argument("--multiuser", help="Runs in multiuser mode, which queues incoming requests instead of blocking them.", action='store_true')
     parser.add_argument("--remotetunnel", help="Uses Cloudflare to create a remote tunnel, allowing you to access koboldcpp remotely over the internet even behind a firewall.", action='store_true')
     parser.add_argument("--foreground", help="Windows only. Sends the terminal to the foreground every time a new prompt is generated. This helps avoid some idle slowdown issues.", action='store_true')
+    parser.add_argument("--preloadstory", help="Configures a prepared story json save file to be hosted on the server, which frontends (such as Kobold Lite) can access over the API.", default="")
 
     # #deprecated hidden args. they do nothing. do not use
     # parser.add_argument("--psutil_set_threads", action='store_true', help=argparse.SUPPRESS)
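Usage note (illustrative; both file paths are placeholders): starting the server with

    python koboldcpp.py --model /path/to/model.bin --preloadstory /path/to/story.json

serves the story JSON verbatim at /api/extra/preloadstory. A connecting Kobold Lite client then loads it automatically, unless persist_session is enabled and the client already holds unsaved work (the safe_to_overwrite check in the klite.embd hunk above).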