added preloadstory

Concedo 2023-11-10 13:05:22 +08:00
parent 6870c31933
commit be92cfa125
5 changed files with 65 additions and 7 deletions

View file

@@ -217,6 +217,9 @@ extern "C"
     int get_last_token_count() {
        return last_token_count;
    }
+    int get_total_gens() {
+       return total_gens;
+    }
     int get_last_stop_reason() {
        return (int)last_stop_reason;
    }
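The new get_total_gens() getter follows the same extern "C" pattern as the existing counters, so it can be bound from Python via ctypes exactly like the others (the binding is added to init_library() later in this commit). A minimal standalone sketch, assuming a hypothetical library filename (the real name varies by platform and backend):

import ctypes

# Hypothetical library path; the actual filename depends on platform/build.
handle = ctypes.CDLL("./koboldcpp_default.so")

# Same binding pattern this commit adds to init_library() in koboldcpp.py.
handle.get_total_gens.restype = ctypes.c_int

print("total generations so far:", handle.get_total_gens())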

View file

@@ -91,4 +91,5 @@ extern bool generation_finished;
 extern float last_eval_time;
 extern float last_process_time;
 extern int last_token_count;
+extern int total_gens;
 extern stop_reason last_stop_reason;

View file

@@ -39,6 +39,7 @@ bool generation_finished;
 float last_process_time = 0;
 float last_eval_time = 0;
 int last_token_count = 0;
+int total_gens = 0;
 stop_reason last_stop_reason = stop_reason::INVALID;
 std::vector<std::string> generated_tokens;
@@ -597,8 +598,8 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
-    const int ShortfallThreshold = 200 + (nctx/40); //dont trigger shifting if the distance between trimstart and currhead < this
-    const int SlackAllowance = 50 + (nctx/80); //in case the end text is slightly modified, be forgiving
+    const int ShortfallThreshold = 200 + (nctx/20); //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 50 + (nctx/60); //in case the end text is slightly modified, be forgiving
     int trimstart = 0;
     int new_tokens_len = new_context_tokens.size();
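As a worked illustration of the retuned constants (integer division, using a hypothetical nctx of 4096): the shortfall threshold rises from 200 + 4096/40 = 302 to 200 + 4096/20 = 404 tokens, so context shifting now requires a larger gap before it triggers, while the slack allowance for slightly modified end text rises from 101 to 118 tokens.

# Effect of the retuned constants at an assumed nctx of 4096:
nctx = 4096
old_shortfall = 200 + nctx // 40   # 302
new_shortfall = 200 + nctx // 20   # 404
old_slack = 50 + nctx // 80        # 101
new_slack = 50 + nctx // 60        # 118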
@@ -1955,6 +1956,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     last_eval_time = pt2;
     last_process_time = pt1;
     last_token_count = realnpredict;
+    total_gens += 1;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
     return output;

View file

@@ -6,7 +6,7 @@ It requires no dependencies, installation or setup.
 Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
 Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
 Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 94
+Current version: 95
 -Concedo
 -->
@@ -3168,6 +3168,7 @@ Current version: 94
 const koboldcpp_abort_endpoint = "/api/extra/abort";
 const koboldcpp_check_endpoint = "/api/extra/generate/check";
 const koboldcpp_truemaxctxlen_endpoint = "/api/extra/true_max_context_length";
+const koboldcpp_preloadstory_endpoint = "/api/extra/preloadstory";
 const oai_models_endpoint = "/models";
 const oai_submit_endpoint = "/completions";
@@ -6092,6 +6093,28 @@ Current version: 94
         console.log("Failed to get true max ctx: " + error);
     });
+    //and check if there's a kcpp savefile preloaded
+    let urls5 = [
+        apply_proxy_url(tmpep + koboldcpp_preloadstory_endpoint),
+    ];
+    Promise.all(urls5.map(url => fetch(url)
+        .then(response => response.json())))
+        .then(values5 => {
+            let tmpstory = values5[0];
+            let is_kai = !(tmpstory.prompt==null);
+            if(is_kai)
+            {
+                let safe_to_overwrite = (gametext_arr.length == 0 && current_memory == "" && current_anote == "" && current_wi.length == 0 && redo_arr.length == 0);
+                if (localsettings.persist_session && !safe_to_overwrite) {
+                    console.log("Preload story: Unsafe to overwrite");
+                } else {
+                    kai_json_load(tmpstory, false);
+                }
+            }
+        }).catch(error => {
+            console.log("Failed to get preloaded story: " + error);
+        });
 }else{
     console.log("Unknown KoboldCpp Check Response: " + data);
 }
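Lite only treats the response as a loadable story when tmpstory.prompt is non-null, and it refuses to overwrite a persistent session that already has content. A sketch of a minimal save file that would pass that prompt check, assuming the usual KoboldAI save fields (the real schema carries more than shown here):

import json

# Hypothetical minimal story; a real KoboldAI save includes more fields
# (world info, author's note template, etc.).
story = {
    "prompt": "Once upon a time...",
    "memory": "",
    "authorsnote": "",
    "actions": [],
}
with open("mystory.json", "w") as f:
    json.dump(story, f)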
@@ -7233,7 +7256,8 @@ Current version: 94
     toggle_invert_colors();
     hide_popups();
-    render_gametext(); //need to always autosave, so that we can switch back to non persistent sessions
+    autosave(); //need to always autosave, so that we can switch back to non persistent sessions
+    render_gametext(false);
 }
 function toggle_instruct_tag_format()

View file

@@ -214,6 +214,7 @@ def init_library():
     handle.get_last_eval_time.restype = ctypes.c_float
     handle.get_last_process_time.restype = ctypes.c_float
     handle.get_last_token_count.restype = ctypes.c_int
+    handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
     handle.token_count.restype = ctypes.c_int
@@ -401,6 +402,7 @@ totalgens = 0
 currentusergenkey = "" #store a special key so polled streaming works even in multiuser
 args = None #global args
 gui_layers_untouched = True
+preloaded_story = None
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
@@ -618,7 +620,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -658,8 +660,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         lastp = handle.get_last_process_time()
         laste = handle.get_last_eval_time()
         lastc = handle.get_last_token_count()
+        totalgens = handle.get_total_gens()
         stopreason = handle.get_last_stop_reason()
-        response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
+        response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "total_gens":totalgens, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
     elif self.path.endswith('/api/extra/generate/check'):
         pendtxtStr = ""
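With total_gens now included in the /api/extra/perf payload, a client can poll the endpoint to watch the running generation count. A standard-library sketch, assuming a default local server on port 5001:

import json, urllib.request

# 5001 is KoboldCpp's default port; adjust if started with --port.
with urllib.request.urlopen("http://localhost:5001/api/extra/perf") as r:
    perf = json.load(r)
print(perf["total_gens"], perf["last_token_count"], perf["stop_reason"])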
@@ -677,6 +680,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode()
             else:
                 response_body = self.embedded_kcpp_docs
+        elif self.path=="/api/extra/preloadstory":
+            if preloaded_story is None:
+                response_body = (json.dumps({}).encode())
+            else:
+                response_body = preloaded_story
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
             self.path = "/api"
             self.send_response(302)
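The endpoint's contract is simple: it returns the raw bytes of the configured save file, or an empty JSON object when nothing was preloaded. A client-side sketch mirroring the check Kobold Lite performs (same local-server assumption as above):

import json, urllib.request

with urllib.request.urlopen("http://localhost:5001/api/extra/preloadstory") as r:
    story = json.load(r)
if story.get("prompt") is not None:  # the same test Lite uses for a KAI-format save
    print("server has a preloaded story")
else:
    print("no story preloaded")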
@@ -1008,7 +1017,8 @@ def show_new_gui():
     model_var = ctk.StringVar()
     lora_var = ctk.StringVar()
     lora_base_var = ctk.StringVar()
+    preloadstory_var = ctk.StringVar()
     port_var = ctk.StringVar(value=defaultport)
     host_var = ctk.StringVar(value="")
@@ -1404,6 +1414,7 @@ def show_new_gui():
     makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=autoset_gpu_layers)
     makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3)
     makefileentry(model_tab, "Lora Base:", "Select Lora Base File", lora_base_var, 5)
+    makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 7)
     # Network Tab
     network_tab = tabcontent["Network"]
@@ -1505,6 +1516,7 @@ def show_new_gui():
     args.model_param = None if model_var.get() == "" else model_var.get()
     args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
+    args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
     args.port_param = defaultport if port_var.get()=="" else int(port_var.get())
     args.host = host_var.get()
@@ -1595,6 +1607,9 @@ def show_new_gui():
     else:
         lora_var.set(dict["lora"][0])
+    if "preloadstory" in dict and dict["preloadstory"]:
+        preloadstory_var.set(dict["preloadstory"])
     if "port_param" in dict and dict["port_param"]:
         port_var.set(dict["port_param"])
@@ -1963,6 +1978,7 @@ def unload_libs():
     del handle.get_last_eval_time
     del handle.get_last_process_time
     del handle.get_last_token_count
+    del handle.get_total_gens
     del handle.get_last_stop_reason
     del handle.abort_generate
     del handle.token_count
@@ -2018,6 +2034,17 @@ def main(launch_args,start_server=True):
             time.sleep(3)
             sys.exit(2)
+    #try to read story if provided
+    if args.preloadstory:
+        if isinstance(args.preloadstory, str) and os.path.exists(args.preloadstory):
+            print(f"Preloading saved story {args.preloadstory} into server...")
+            with open(args.preloadstory, mode='rb') as f:
+                global preloaded_story
+                preloaded_story = f.read()
+            print("Saved story preloaded.")
+        else:
+            print(f"Warning: Saved story file {args.preloadstory} invalid or not found. No story will be preloaded into server.")
     # sanitize and replace the default vanity name. remember me....
     if args.model_param!="":
         newmdldisplayname = os.path.basename(args.model_param)
@@ -2201,6 +2228,7 @@ if __name__ == '__main__':
     parser.add_argument("--multiuser", help="Runs in multiuser mode, which queues incoming requests instead of blocking them.", action='store_true')
     parser.add_argument("--remotetunnel", help="Uses Cloudflare to create a remote tunnel, allowing you to access koboldcpp remotely over the internet even behind a firewall.", action='store_true')
     parser.add_argument("--foreground", help="Windows only. Sends the terminal to the foreground every time a new prompt is generated. This helps avoid some idle slowdown issues.", action='store_true')
+    parser.add_argument("--preloadstory", help="Configures a prepared story json save file to be hosted on the server, which frontends (such as Kobold Lite) can access over the API.", default="")
     # #deprecated hidden args. they do nothing. do not use
     # parser.add_argument("--psutil_set_threads", action='store_true', help=argparse.SUPPRESS)
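With the new flag, a typical launch might look like python koboldcpp.py --model model.gguf --preloadstory mystory.json (filenames hypothetical); on connect, Kobold Lite fetches /api/extra/preloadstory and loads the story automatically whenever the current session is safe to overwrite.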