back to http.server, improved implementation

SammCheese 2023-06-09 12:17:55 +02:00
parent 4f665cd63d
commit e6231c3055
4 changed files with 196 additions and 150 deletions

diff --git a/expose.cpp b/expose.cpp

@@ -24,9 +24,8 @@ std::string executable_path = "";
 std::string lora_filename = "";
-static std::string current_token = "";
-static bool new_token_available = false;
-static bool finished_stream = false;
+bool generation_finished;
+std::vector<std::string> generated_tokens;
 
 extern "C"
 {
@@ -213,35 +212,17 @@ extern "C"
         return gpttype_generate(inputs, output);
     }
 
-    const char* new_token() {
-        if (new_token_available) {
-            new_token_available = false;
-            return current_token.c_str();
-        }
-        return nullptr;
+    const char* new_token(int idx) {
+        if (generated_tokens.size() <= idx || idx < 0) return nullptr;
+        return generated_tokens[idx].c_str();
     }
 
-    bool is_locked() {
-        return !new_token_available;
+    int get_stream_count() {
+        return generated_tokens.size();
     }
 
     bool has_finished() {
-        return finished_stream;
+        return generation_finished;
     }
-
-    // TODO: dont duplicate code
-    void bind_set_stream_finished(bool status) {
-        finished_stream = status;
-    }
 }
-
-void receive_current_token(std::string token) {
-    current_token = token;
-    new_token_available = true;
-}
-
-void set_stream_finished(bool status) {
-    finished_stream = status;
-}
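The exported streaming API changes from a locked single-token handoff (new_token() guarded by is_locked()) to an append-only token list addressed by index: the consumer polls get_stream_count() and reads forward, so no lock or reset call is needed. A minimal ctypes consumer sketch of the new contract (the library path and poll interval are illustrative assumptions, not part of this commit):

    # Sketch: poll the index-based token API from Python.
    import ctypes, time

    handle = ctypes.CDLL("./koboldcpp.dll")  # hypothetical library path
    handle.new_token.restype = ctypes.c_char_p
    handle.new_token.argtypes = [ctypes.c_int]
    handle.get_stream_count.restype = ctypes.c_int
    handle.has_finished.restype = ctypes.c_bool

    idx = 0
    while not handle.has_finished() or idx < handle.get_stream_count():
        while idx < handle.get_stream_count():
            token = handle.new_token(idx)  # bytes, or None for an out-of-range index
            if token is not None:
                print(token.decode("utf-8"), end="", flush=True)
            idx += 1
        time.sleep(0.01)  # plain polling; the token list is append-only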

diff --git a/expose.h b/expose.h

@@ -50,5 +50,5 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
-extern void receive_current_token(std::string token);
-extern void set_stream_finished(bool status = true);
+extern std::vector<std::string> generated_tokens;
+extern bool generation_finished;

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp

@@ -736,6 +736,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     params.n_batch = n_batch;
     params.n_threads = n_threads;
 
+    generation_finished = false; // Set current generation status
+    generated_tokens.clear(); // New Generation, new tokens
+
     if (params.repeat_last_n < 1)
     {
         params.repeat_last_n = 1;
@@ -1041,7 +1044,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
-            set_stream_finished(true);
+            generation_finished = true;
             return output;
         }
     }
@@ -1155,7 +1158,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
             if (stream_sse)
             {
-                receive_current_token(tokenizedstr);
+                generated_tokens.push_back(tokenizedstr);
             }
             concat_output += tokenizedstr;
         }
@@ -1224,7 +1227,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
-    set_stream_finished(true);
+    generation_finished = true;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
     return output;

diff --git a/koboldcpp.py b/koboldcpp.py

@@ -5,8 +5,7 @@
 import ctypes
 import os
 import argparse
-import json, sys, time, asyncio, socket
-from aiohttp import web
+import json, sys, http.server, time, asyncio, socket, threading
 from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
@@ -137,6 +136,9 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
     handle.new_token.restype = ctypes.c_char_p
+    handle.new_token.argtypes = [ctypes.c_int]
+    handle.get_stream_count.restype = ctypes.c_int
+    handle.has_finished.restype = ctypes.c_bool
 
 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -215,25 +217,23 @@ modelbusy = False
 defaultport = 5001
 KcppVersion = "1.29"
 
-class ServerRequestHandler:
+class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
-    app = web.Application()
 
     def __init__(self, addr, port, embedded_kailite):
        self.addr = addr
        self.port = port
        self.embedded_kailite = embedded_kailite
 
+    def __call__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     async def generate_text(self, newprompt, genparams, basic_api_flag):
         loop = asyncio.get_event_loop()
         executor = ThreadPoolExecutor()
 
         def run_blocking():
-            # Reset finished status before generating
-            handle.bind_set_stream_finished(False)
             if basic_api_flag:
                 return generate(
                     prompt=newprompt,
@@ -249,7 +249,7 @@ class ServerRequestHandler:
                     seed=genparams.get('sampler_seed', -1),
                     stop_sequence=genparams.get('stop_sequence', [])
                 )
-            else:
+
             return generate(prompt=newprompt,
                 max_context_length=genparams.get('max_context_length', maxctx),
                 max_length=genparams.get('max_length', 50),
@@ -277,104 +277,139 @@ class ServerRequestHandler:
             print(f"Generate: Error while generating {e}")
 
-    async def send_sse_event(self, response, event, data):
-        await response.write(f'event: {event}\n'.encode())
-        await response.write(f'data: {data}\n\n'.encode())
+    async def send_sse_event(self, event, data):
+        self.wfile.write(f'event: {event}\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())
 
-    async def handle_sse_stream(self, request):
-        response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
-        await response.prepare(request)
+    async def handle_sse_stream(self):
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Connection", "keep-alive")
+        self.end_headers()
 
+        current_token = 0;
         while not handle.has_finished():
-            if not handle.is_locked():
-                token = ctypes.string_at(handle.new_token()).decode('utf-8')
-                event_data = {"token": token}
+            if current_token < handle.get_stream_count():
+                token = handle.new_token(current_token)
+                if token is None: # Token isnt ready yet, received nullpointer
+                    continue
+                current_token += 1
+                tokenStr = ctypes.string_at(token).decode('utf-8')
+                event_data = {"token": tokenStr}
                 event_str = json.dumps(event_data)
-                await self.send_sse_event(response, "message", event_str)
+                await self.send_sse_event("message", event_str)
             await asyncio.sleep(0)
 
-        await response.write_eof()
-        await response.force_close()
+        await self.wfile.close()
 
-    async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
+    async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
         tasks = []
 
         if stream_flag:
-            tasks.append(self.handle_sse_stream(request,))
+            tasks.append(self.handle_sse_stream())
 
         generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
         tasks.append(generate_task)
 
         try:
             await asyncio.gather(*tasks)
-            print("done")
             generate_result = generate_task.result()
             return generate_result
         except Exception as e:
             print(e)
 
-    async def handle_get(self, request):
+    def do_GET(self):
         global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
-        path = request.path.rstrip('/')
+        self.path = self.path.rstrip('/')
+        response_body = None
 
-        if path in ["", "/?"] or path.startswith(('/?', '?')):
+        if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
-            if args.stream and not "streaming=1" in path:
+            if args.stream and not "streaming=1" in self.path:
-                path = path.replace("streaming=0", "")
+                self.path = self.path.replace("streaming=0","")
-                if path.startswith(('/?', '?')):
+                if self.path.startswith(('/?','?')):
-                    path += "&streaming=1"
+                    self.path += "&streaming=1"
                 else:
-                    path = path + "?streaming=1"
+                    self.path = self.path + "?streaming=1"
+                self.send_response(302)
+                self.send_header("Location", self.path)
+                self.end_headers()
+                print("Force redirect to streaming mode, as --stream is set.")
+                return None
 
             if self.embedded_kailite is None:
-                return web.Response(body=f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.".encode())
+                response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
             else:
-                return web.Response(body=self.embedded_kailite, content_type='text/html')
+                response_body = self.embedded_kailite
 
-        elif path.endswith(('/api/v1/model', '/api/latest/model')):
-            return web.json_response({'result': friendlymodelname})
+        elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
+            response_body = (json.dumps({'result': friendlymodelname }).encode())
-        elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
-            return web.json_response({"value": maxlen})
+        elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
+            response_body = (json.dumps({"value": maxlen}).encode())
-        elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
-            return web.json_response({"value": maxctx})
+        elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
+            response_body = (json.dumps({"value": maxctx}).encode())
-        elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
-            return web.json_response({"value": ""})
+        elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
+            response_body = (json.dumps({"value":""}).encode())
-        elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
-            return web.json_response({"values": []})
+        elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
+            response_body = (json.dumps({"values": []}).encode())
-        elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
-            return web.json_response({"result": "1.2.2"})
+        elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
+            response_body = (json.dumps({"result":"1.2.2"}).encode())
-        elif path.endswith(('/api/extra/version')):
-            return web.json_response({"result": "KoboldCpp", "version": KcppVersion})
+        elif self.path.endswith(('/api/extra/version')):
+            response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
 
-        return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.")
+        if response_body is None:
+            self.send_response(404)
+            self.end_headers()
+            rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
+            self.wfile.write(rp.encode())
+        else:
+            self.send_response(200)
+            self.send_header('Content-Length', str(len(response_body)))
+            self.end_headers()
+            self.wfile.write(response_body)
+        return
 
-    async def handle_post(self, request):
+    def do_POST(self):
         global modelbusy
-        body = await request.content.read()
+        content_length = int(self.headers['Content-Length'])
+        body = self.rfile.read(content_length)
         basic_api_flag = False
         kai_api_flag = False
         kai_sse_stream_flag = False
-        path = request.path.rstrip('/')
+        self.path = self.path.rstrip('/')
 
         if modelbusy:
-            return web.json_response(
-                {"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}},
-                status=503,
-            )
+            self.send_response(503)
+            self.end_headers()
+            self.wfile.write(json.dumps({"detail": {
+                "msg": "Server is busy; please try again later.",
+                "type": "service_unavailable",
+            }}).encode())
+            return
 
-        if path.endswith('/request'):
+        if self.path.endswith('/request'):
             basic_api_flag = True
 
-        if path.endswith(('/api/v1/generate', '/api/latest/generate')):
+        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True
 
-        if path.endswith('/api/extra/generate/stream'):
+        if self.path.endswith('/api/extra/generate/stream'):
             kai_api_flag = True
             kai_sse_stream_flag = True
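The new handle_sse_stream writes standard server-sent events straight to the socket: one `message` event per token, each carrying a small JSON body. Any SSE-capable client can therefore consume /api/extra/generate/stream; a minimal client sketch (the requests library and the payload fields are illustrative assumptions, not part of the diff):

    # Sketch: consume the token stream as server-sent events.
    import json, requests  # requests is an assumed third-party dependency

    resp = requests.post(
        "http://localhost:5001/api/extra/generate/stream",  # defaultport from this file
        json={"prompt": "Hello,", "max_length": 50},        # illustrative params
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(json.loads(line[len("data: "):])["token"], end="", flush=True)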
@@ -383,66 +418,94 @@ class ServerRequestHandler:
         try:
             genparams = json.loads(body)
         except ValueError as e:
-            return web.Response(status=503)
+            return self.send_response(503)
 
         utfprint("\nInput: " + json.dumps(genparams))
 
         modelbusy = True
 
         if kai_api_flag:
             fullprompt = genparams.get('prompt', "")
         else:
             fullprompt = genparams.get('text', "")
         newprompt = fullprompt
 
-        gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
+        gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+
+        try:
+            self.wfile.write(json.dumps(gen).encode())
+        except:
+            print("Generate: The response could not be sent, maybe connection was terminated?")
 
         modelbusy = False
-        if not kai_sse_stream_flag:
-            return web.Response(body=json.dumps(gen).encode())
-        return web.Response();
+        return
 
-        return web.Response(status=404)
+        self.send_response(404)
+        self.end_headers()
 
-    async def handle_options(self):
-        return web.Response()
+    def do_OPTIONS(self):
+        self.send_response(200)
+        self.end_headers()
 
-    async def handle_head(self):
-        return web.Response()
+    def do_HEAD(self):
+        self.send_response(200)
+        self.end_headers()
 
-    async def start_server(self):
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind((self.addr, self.port))
-        sock.listen(5)
-
-        self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
-        self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
-        self.app.router.add_route('OPTIONS', '/', self.handle_options)
-        self.app.router.add_route('HEAD', '/', self.handle_head)
-
-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.SockSite(runner, sock)
-        await site.start()
-
-        # Keep Alive
-        try:
-            while True:
-                await asyncio.sleep(3600)
-        except KeyboardInterrupt:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-        finally:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-
-async def run_server(addr, port, embedded_kailite=None):
-    handler = ServerRequestHandler(addr, port, embedded_kailite)
-    await handler.start_server()
+    def end_headers(self):
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Methods', '*')
+        self.send_header('Access-Control-Allow-Headers', '*')
+        if "/api" in self.path:
+            if self.path.endswith("/stream"):
+                self.send_header('Content-type', 'text/event-stream')
+            self.send_header('Content-type', 'application/json')
+        else:
+            self.send_header('Content-type', 'text/html')
+        return super(ServerRequestHandler, self).end_headers()
+
+def RunServerMultiThreaded(addr, port, embedded_kailite = None):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind((addr, port))
+    sock.listen(5)
+
+    class Thread(threading.Thread):
+        def __init__(self, i):
+            threading.Thread.__init__(self)
+            self.i = i
+            self.daemon = True
+            self.start()
+
+        def run(self):
+            handler = ServerRequestHandler(addr, port, embedded_kailite)
+            with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
+                try:
+                    self.httpd.socket = sock
+                    self.httpd.server_bind = self.server_close = lambda self: None
+                    self.httpd.serve_forever()
+                except (KeyboardInterrupt,SystemExit):
+                    self.httpd.server_close()
+                    sys.exit(0)
+                finally:
+                    self.httpd.server_close()
+                    sys.exit(0)
+
+        def stop(self):
+            self.httpd.server_close()
+
+    numThreads = 6
+    threadArr = []
+    for i in range(numThreads):
+        threadArr.append(Thread(i))
+    while 1:
+        try:
+            time.sleep(10)
+        except KeyboardInterrupt:
+            for i in range(numThreads):
+                threadArr[i].stop()
+            sys.exit(0)
 
 def show_gui():
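RunServerMultiThreaded works around http.server's single-threaded accept loop: one socket is bound and set listening up front, then six HTTPServer instances run serve_forever() in daemon threads against that same socket. Passing bind_and_activate=False (and stubbing server_bind) keeps each instance from binding again. A stripped-down sketch of the same pattern (handler class and port are illustrative):

    # Sketch: several HTTPServer instances accepting from one shared socket.
    import http.server, socket, threading, time

    class HelloHandler(http.server.BaseHTTPRequestHandler):  # hypothetical handler
        def do_GET(self):
            body = f"served by {threading.current_thread().name}".encode()
            self.send_response(200)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("127.0.0.1", 5001))
    sock.listen(5)

    for i in range(6):  # mirrors the commit's numThreads = 6
        # bind_and_activate=False: reuse the already-listening socket
        httpd = http.server.HTTPServer(("127.0.0.1", 5001), HelloHandler, False)
        httpd.socket = sock
        threading.Thread(target=httpd.serve_forever, name=f"worker-{i}", daemon=True).start()

    while True:
        time.sleep(10)  # keep the main thread alive, like the commit's loop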
@@ -514,15 +577,14 @@ def show_gui():
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
 
-    frm3 = tk.Frame(root)
-    tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
-    tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
-    tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
-    tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
-    tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
-    tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
-
-    frm3.grid(row=5,column=0,pady=4)
+    frameD = tk.Frame(root)
+    tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
+    tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
+    tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
+    tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
+    tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
+    tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
+    frameD.grid(row=5,column=0,pady=4)
 
     # Create button, it will change label text
     tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0)
@@ -702,7 +764,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
             print(f"Please connect to custom endpoint at {epurl}")
 
-    asyncio.run(run_server(args.host, args.port, embedded_kailite))
+    asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
 
 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually