From da2db0302ca822700062cb35bfcbde40c1f00e20 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Tue, 19 Dec 2023 22:23:19 +0800 Subject: [PATCH] Added support for ssl cert and key --- koboldcpp.py | 67 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 298cb85d3..27bbd6553 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -409,6 +409,7 @@ currentusergenkey = "" #store a special key so polled streaming works even in mu args = None #global args gui_layers_untouched = True preloaded_story = None +sslvalid = False class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): sys_version = "" @@ -976,12 +977,19 @@ def RunServerMultiThreaded(addr, port, embedded_kailite = None, embedded_kcpp_do self.start() def run(self): - global exitcounter + global exitcounter, sslvalid handler = ServerRequestHandler(addr, port, embedded_kailite, embedded_kcpp_docs) with http.server.HTTPServer((addr, port), handler, False) as self.httpd: try: self.httpd.socket = sock self.httpd.server_bind = self.server_close = lambda self: None + + if args.ssl and sslvalid: + import ssl + certpath = os.path.abspath(args.ssl[0]) + keypath = os.path.abspath(args.ssl[1]) + self.httpd.socket = ssl.wrap_socket(self.httpd.socket, keyfile=keypath, certfile=certpath, server_side=True) + self.httpd.serve_forever() except (KeyboardInterrupt,SystemExit): exitcounter = 999 @@ -1133,6 +1141,8 @@ def show_new_gui(): horde_apikey_var = ctk.StringVar(value="") horde_workername_var = ctk.StringVar(value="") usehorde_var = ctk.IntVar() + ssl_cert_var = ctk.StringVar() + ssl_key_var = ctk.StringVar() def tabbuttonaction(name): for t in tabcontent: @@ -1191,16 +1201,20 @@ def show_new_gui(): return entry, label - def makefileentry(parent, text, searchtext, var, row=0, width=250, filetypes=[], onchoosefile=None): + def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False): makelabel(parent, text, row) def getfilename(var, text): var.set(askopenfilename(title=text,filetypes=filetypes)) if onchoosefile: onchoosefile(var.get()) entry = ctk.CTkEntry(parent, width, textvariable=var) - entry.grid(row=row+1, column=0, padx=8, stick="nw") button = ctk.CTkButton(parent, 50, text="Browse", command= lambda a=var,b=searchtext:getfilename(a,b)) - button.grid(row=row+1, column=1, stick="nw") + if singlerow: + entry.grid(row=row, column=1, padx=8, stick="w") + button.grid(row=row, column=1, padx=144, stick="nw") + else: + entry.grid(row=row+1, column=0, padx=8, stick="nw") + button.grid(row=row+1, column=1, stick="nw") return # decided to follow yellowrose's and kalomaze's suggestions, this function will automatically try to determine GPU identifiers @@ -1546,21 +1560,24 @@ def show_new_gui(): makecheckbox(network_tab, "Remote Tunnel", remotetunnel, 3, 1) makecheckbox(network_tab, "Quiet Mode", quietmode, 4) - # horde - makelabel(network_tab, "Horde:", 5).grid(pady=10) + makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=130 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True) + makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=130, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True) - horde_name_entry, horde_name_label = makelabelentry(network_tab, "Horde Model Name:", horde_name_var, 10, 180) - horde_gen_entry, horde_gen_label = makelabelentry(network_tab, "Gen. Length:", horde_gen_var, 11, 50) - horde_context_entry, horde_context_label = makelabelentry(network_tab, "Max Context:",horde_context_var, 12, 50) - horde_apikey_entry, horde_apikey_label = makelabelentry(network_tab, "API Key (If Embedded Worker):",horde_apikey_var, 13, 180) - horde_workername_entry, horde_workername_label = makelabelentry(network_tab, "Horde Worker Name:",horde_workername_var, 14, 180) + # horde + makelabel(network_tab, "Horde:", 18).grid(pady=10) + + horde_name_entry, horde_name_label = makelabelentry(network_tab, "Horde Model Name:", horde_name_var, 20, 180) + horde_gen_entry, horde_gen_label = makelabelentry(network_tab, "Gen. Length:", horde_gen_var, 21, 50) + horde_context_entry, horde_context_label = makelabelentry(network_tab, "Max Context:",horde_context_var, 22, 50) + horde_apikey_entry, horde_apikey_label = makelabelentry(network_tab, "API Key (If Embedded Worker):",horde_apikey_var, 23, 180) + horde_workername_entry, horde_workername_label = makelabelentry(network_tab, "Horde Worker Name:",horde_workername_var, 24, 180) def togglehorde(a,b,c): labels = [horde_name_label, horde_gen_label, horde_context_label, horde_apikey_label, horde_workername_label] for idx, item in enumerate([horde_name_entry, horde_gen_entry, horde_context_entry, horde_apikey_entry, horde_workername_entry]): if usehorde_var.get() == 1: - item.grid(row=10 + idx, column = 1, padx=8, pady=1, stick="nw") - labels[idx].grid(row=10 + idx, padx=8, pady=1, stick="nw") + item.grid(row=20 + idx, column = 1, padx=8, pady=1, stick="nw") + labels[idx].grid(row=20 + idx, padx=8, pady=1, stick="nw") else: item.grid_forget() labels[idx].grid_forget() @@ -1568,7 +1585,7 @@ def show_new_gui(): basefile = os.path.basename(model_var.get()) horde_name_var.set(sanitize_string(os.path.splitext(basefile)[0])) - makecheckbox(network_tab, "Configure for Horde", usehorde_var, 6, command=togglehorde) + makecheckbox(network_tab, "Configure for Horde", usehorde_var, 19, command=togglehorde) togglehorde(1,1,1) # launch @@ -1639,6 +1656,9 @@ def show_new_gui(): args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()]) args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get() + args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()]) + + args.port_param = defaultport if port_var.get()=="" else int(port_var.get()) args.host = host_var.get() args.multiuser = multiuser_var.get() @@ -1729,6 +1749,11 @@ def show_new_gui(): else: lora_var.set(dict["lora"][0]) + if "ssl" in dict and dict["ssl"]: + if len(dict["ssl"]) == 2: + ssl_cert_var.set(dict["ssl"][0]) + ssl_key_var.set(dict["ssl"][1]) + if "preloadstory" in dict and dict["preloadstory"]: preloadstory_var.set(dict["preloadstory"]) @@ -2292,11 +2317,19 @@ def main(launch_args,start_server=True): if args.port_param!=defaultport: args.port = args.port_param + global sslvalid + if args.ssl: + if len(args.ssl)==2 and isinstance(args.ssl[0], str) and os.path.exists(args.ssl[0]) and isinstance(args.ssl[1], str) and os.path.exists(args.ssl[1]): + sslvalid = True + print("SSL configuration is valid and will be used.") + else: + print("Your SSL configuration is INVALID. SSL will not be used.") epurl = "" + httpsaffix = ("https" if sslvalid else "http") if args.host=="": - epurl = f"http://localhost:{args.port}" + epurl = f"{httpsaffix}://localhost:{args.port}" else: - epurl = f"http://{args.host}:{args.port}" + epurl = f"{httpsaffix}://{args.host}:{args.port}" if not args.remotetunnel: print(f"Starting Kobold API on port {args.port} at {epurl}/api/") print(f"Starting OpenAI Compatible API on port {args.port} at {epurl}/v1/") @@ -2396,6 +2429,8 @@ if __name__ == '__main__': parser.add_argument("--foreground", help="Windows only. Sends the terminal to the foreground every time a new prompt is generated. This helps avoid some idle slowdown issues.", action='store_true') parser.add_argument("--preloadstory", help="Configures a prepared story json save file to be hosted on the server, which frontends (such as Kobold Lite) can access over the API.", default="") parser.add_argument("--quiet", help="Enable quiet mode, which hides generation inputs and outputs in the terminal. Quiet mode is automatically enabled when running --hordeconfig.", action='store_true') + parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+') + # #deprecated hidden args. they do nothing. do not use # parser.add_argument("--psutil_set_threads", action='store_true', help=argparse.SUPPRESS)