draft: token streaming

SammCheese 2023-06-07 00:48:00 +02:00
parent a6a0fa338a
commit 97971291e9
4 changed files with 183 additions and 170 deletions

View file

@@ -23,6 +23,11 @@
 std::string executable_path = "";
 std::string lora_filename = "";
+static std::string current_token = "";
+static bool new_token_available = false;
+static bool finished_stream = false;
 extern "C"
 {
@@ -205,6 +210,33 @@ extern "C"
 generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
 {
+finished_stream = false;
 return gpttype_generate(inputs, output);
 }
+const char* new_token() {
+if (new_token_available) {
+new_token_available = false;
+return current_token.c_str();
+}
+return nullptr;
+}
+bool is_locked() {
+return !new_token_available;
+}
+bool has_finished() {
+return finished_stream;
+}
+}
+void receive_current_token(std::string token) {
+current_token = token;
+new_token_available = true;
+}
+void set_stream_finished() {
+finished_stream = true;
 }
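
The three functions exported above define a simple polling handshake: gpttype_generate publishes each sampled token through receive_current_token(), and a caller keeps checking is_locked(), new_token() and has_finished() until generation ends. A minimal consumer-side sketch in Python, assuming handle is the ctypes handle for this library as loaded by init_library() in koboldcpp.py; drain_tokens is a hypothetical helper, not part of this commit:

import ctypes, time

def drain_tokens(handle):
    # Poll the exported C API until the generator signals completion.
    handle.new_token.restype = ctypes.c_char_p   # a NULL pointer arrives as None
    handle.has_finished.restype = ctypes.c_bool
    handle.is_locked.restype = ctypes.c_bool
    text = ""
    while True:
        finished = handle.has_finished()
        if not handle.is_locked():
            raw = handle.new_token()
            if raw is not None:                  # guard the nullptr case
                text += raw.decode('utf-8', errors='ignore')
        if finished:
            return text
        time.sleep(0.01)                         # avoid a busy spin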

View file

@@ -18,6 +18,7 @@ struct load_model_inputs
 const int clblast_info = 0;
 const int blasbatchsize = 512;
 const bool debugmode;
+const bool stream_sse;
 const int forceversion = 0;
 const int gpulayers = 0;
 };
@@ -48,3 +49,6 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
+extern void receive_current_token(std::string token);
+extern void set_stream_finished();

View file

@@ -63,6 +63,7 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
+static bool stream_sse = true;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
 static std::vector<gpt_vocab::id> current_context_tokens;
@@ -1040,6 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 fprintf(stderr, "Failed to predict\n");
 snprintf(output.text, sizeof(output.text), "%s", "");
 output.status = 0;
+set_stream_finished();
 return output;
 }
 }
@@ -1149,7 +1151,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 for (auto id : embd)
 {
-concat_output += FileFormatTokenizeID(id,file_format);
+std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
+if (stream_sse)
+{
+receive_current_token(tokenizedstr);
+}
+concat_output += tokenizedstr;
 }
 if (startedsampling)
@@ -1216,6 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
 fflush(stdout);
 output.status = 1;
+set_stream_finished();
 snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 return output;

View file

@@ -5,7 +5,8 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time
+import json, http.server, threading, socket, sys, time, asyncio
+from aiohttp import web
 stop_token_max = 10
@@ -134,6 +135,7 @@ def init_library():
 handle.load_model.restype = ctypes.c_bool
 handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
 handle.generate.restype = generation_outputs
+handle.new_token.restype = ctypes.c_char_p
 def load_model(model_filename):
 inputs = load_model_inputs()
@@ -212,105 +214,145 @@ modelbusy = False
 defaultport = 5001
 KcppVersion = "1.29"
-class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
+class ServerRequestHandler:
 sys_version = ""
 server_version = "ConcedoLlamaForKoboldServer"
+app = web.Application()
 def __init__(self, addr, port, embedded_kailite):
 self.addr = addr
 self.port = port
 self.embedded_kailite = embedded_kailite
-def __call__(self, *args, **kwargs):
+async def send_sse_event(self, response, event, data):
-super().__init__(*args, **kwargs)
+await response.write(f'event: {event}\n'.encode())
+await response.write(f'data: {data}\n\n'.encode())
-def do_GET(self):
-global maxctx, maxlen, friendlymodelname, KcppVersion
-self.path = self.path.rstrip('/')
-response_body = None
-if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
+async def handle_sse_stream(self, request):
-if args.stream and not "streaming=1" in self.path:
+response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
-self.path = self.path.replace("streaming=0","")
+await response.prepare(request)
-if self.path.startswith(('/?','?')):
-self.path += "&streaming=1"
+stream_finished = False
+while True:
+if handle.has_finished():
+stream_finished = True
+if not handle.is_locked():
+token = ctypes.string_at(handle.new_token()).decode('utf-8')
+event_data = {"finished": stream_finished, "token": token}
+event_str = f"data: {json.dumps(event_data)}"
+await self.send_sse_event(response, "message", event_str)
+print(token)
+print(event_data)
+if stream_finished:
+break
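
handle_sse_stream above polls the C-side token buffer and forwards every token to the client as a server-sent event. A minimal client-side sketch of consuming that stream with aiohttp, assuming the handler is reached through the new /api/v1/generate/stream route added below; read_token_stream is a hypothetical helper, not part of this commit. Note that send_sse_event already receives a string of the form "data: {...}", so the wire format carries a doubled "data:" prefix which the client strips:

import asyncio, json
import aiohttp

async def read_token_stream(url, payload):
    # Hypothetical client: POST a generate request, then read SSE lines as they arrive.
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as resp:
            async for raw in resp.content:        # yields one line of bytes at a time
                line = raw.decode('utf-8').strip()
                while line.startswith('data:'):   # tolerate the doubled prefix
                    line = line[len('data:'):].strip()
                if not line or line.startswith('event:'):
                    continue
                event = json.loads(line)
                print(event.get('token', ''), end='', flush=True)
                if event.get('finished'):
                    break

# asyncio.run(read_token_stream('http://localhost:5001/api/v1/generate/stream', {'prompt': 'Hello'}))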
+async def generate_text(self, newprompt, genparams):
+recvtxt = generate(
+prompt=newprompt,
+max_context_length=genparams.get('max_context_length', maxctx),
+max_length=genparams.get('max_length', 50),
+temperature=genparams.get('temperature', 0.8),
+top_k=genparams.get('top_k', 120),
+top_a=genparams.get('top_a', 0.0),
+top_p=genparams.get('top_p', 0.85),
+typical_p=genparams.get('typical', 1.0),
+tfs=genparams.get('tfs', 1.0),
+rep_pen=genparams.get('rep_pen', 1.1),
+rep_pen_range=genparams.get('rep_pen_range', 128),
+seed=genparams.get('sampler_seed', -1),
+stop_sequence=genparams.get('stop_sequence', [])
+)
+res = {"results": [{"text": recvtxt}]}
+try:
+return web.json_response(res)
+except:
+print("Generate: The response could not be sent, maybe the connection was terminated?")
+async def handle_request(self, request, genparams, newprompt, stream_flag):
+if stream_flag:
+self.handle_sse_stream(request)
+# RUN THESE CONCURRENTLY WITHOUT BLOCKING EACHOTHER
+self.generate_text(newprompt, genparams)
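
The all-caps comment marks the open problem in this draft: handle_sse_stream and generate_text are coroutines, so calling them without await only creates coroutine objects, and awaiting them one after the other would serialize the stream behind the generation call. A minimal sketch of one way to run them concurrently with asyncio (already imported at the top of the file in this commit); it assumes the blocking ctypes generate() call is eventually moved off the event loop, for example with asyncio.to_thread, otherwise the SSE task still cannot make progress while generation runs:

async def handle_request(self, request, genparams, newprompt, stream_flag):
    # Sketch: schedule both coroutines as tasks and wait for both to finish.
    if stream_flag:
        sse_task = asyncio.create_task(self.handle_sse_stream(request))
        gen_task = asyncio.create_task(self.generate_text(newprompt, genparams))
        # generate() is a blocking ctypes call; in practice it would need to run
        # in a worker thread (asyncio.to_thread / run_in_executor) so the event
        # loop can keep servicing sse_task while it runs.
        await asyncio.gather(sse_task, gen_task)
    else:
        await self.generate_text(newprompt, genparams)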
+async def handle_get(self, request):
+global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
+path = request.path.rstrip('/')
+if path in ["", "/?"] or path.startswith(('/?','?')):
+if args.stream and not "streaming=1" in path:
+path = path.replace("streaming=0","")
+if path.startswith(('/?','?')):
+path += "&streaming=1"
 else:
-self.path = self.path + "?streaming=1"
+path = path + "?streaming=1"
-self.send_response(302)
+raise web.HTTPFound(path)
-self.send_header("Location", self.path)
-self.end_headers()
-print("Force redirect to streaming mode, as --stream is set.")
-return None
 if self.embedded_kailite is None:
-response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
+return web.Response(
+text="Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect."
+)
 else:
-response_body = self.embedded_kailite
+return web.Response(body=self.embedded_kailite)
-elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
+elif path.endswith(('/api/v1/model', '/api/latest/model')):
-response_body = (json.dumps({'result': friendlymodelname }).encode())
+return web.json_response({'result': friendlymodelname})
-elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
+elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
-response_body = (json.dumps({"value": maxlen}).encode())
+return web.json_response({"value": maxlen})
-elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
+elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
-response_body = (json.dumps({"value": maxctx}).encode())
+return web.json_response({"value": maxctx})
-elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
+elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
-response_body = (json.dumps({"value":""}).encode())
+return web.json_response({"value": ""})
-elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
+elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
-response_body = (json.dumps({"values": []}).encode())
+return web.json_response({"values": []})
-elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
+elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
-response_body = (json.dumps({"result":"1.2.2"}).encode())
+return web.json_response({"result": "1.2.2"})
-elif self.path.endswith(('/api/extra/version')):
+elif path.endswith(('/api/extra/version')):
-response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
+return web.json_response({"result": "KoboldCpp", "version": KcppVersion})
-if response_body is None:
+return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.")
-self.send_response(404)
-self.end_headers()
-rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
-self.wfile.write(rp.encode())
-else:
-self.send_response(200)
-self.send_header('Content-Length', str(len(response_body)))
-self.end_headers()
-self.wfile.write(response_body)
-return
-def do_POST(self):
+async def handle_post(self, request):
 global modelbusy
-content_length = int(self.headers['Content-Length'])
+body = await request.content.read()
-body = self.rfile.read(content_length)
 basic_api_flag = False
 kai_api_flag = False
-self.path = self.path.rstrip('/')
+kai_stream_flag = True
+path = request.path.rstrip('/')
+print(request)
 if modelbusy:
-self.send_response(503)
+return web.json_response(
-self.end_headers()
+{"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}},
-self.wfile.write(json.dumps({"detail": {
+status=503,
-"msg": "Server is busy; please try again later.",
+)
-"type": "service_unavailable",
-}}).encode())
-return
-if self.path.endswith('/request'):
+if path.endswith('/request'):
 basic_api_flag = True
-if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
+if path.endswith(('/api/v1/generate', '/api/latest/generate')):
 kai_api_flag = True
+if path.endswith('/api/v1/generate/stream'):
+kai_api_flag = True
+kai_stream_flag = True
 if basic_api_flag or kai_api_flag:
 genparams = None
 try:
 genparams = json.loads(body)
 except ValueError as e:
-self.send_response(503)
+return web.Response(status=503)
-self.end_headers()
-return
 utfprint("\nInput: " + json.dumps(genparams))
 modelbusy = True
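
For reference, the new /api/v1/generate/stream route accepts the same JSON body as the existing /api/v1/generate endpoint: every sampler field read through genparams.get() in generate_text() is optional and falls back to the default shown there. A small illustrative payload; the prompt text, host and port are placeholders, the 'prompt' field follows the KoboldAI generate API, and the basic /request branch shown in the next hunk reads 'text' instead:

import json, urllib.request

payload = {
    "prompt": "Once upon a time",
    "max_context_length": 1024,
    "max_length": 50,
    "temperature": 0.8,
    "top_k": 120,
    "top_p": 0.85,
    "rep_pen": 1.1,
    "rep_pen_range": 128,
    "sampler_seed": -1,
    "stop_sequence": ["\n\n"]
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate/stream",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# urllib.request.urlopen(req) would send it, but the reply is a text/event-stream,
# so a streaming-aware client (see the sketch after handle_sse_stream above) is preferable.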
@@ -320,115 +362,41 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 fullprompt = genparams.get('text', "")
 newprompt = fullprompt
-recvtxt = ""
+await self.handle_request(request, genparams, newprompt, kai_stream_flag)
-res = {}
-if kai_api_flag:
-recvtxt = generate(
-prompt=newprompt,
-max_context_length=genparams.get('max_context_length', maxctx),
-max_length=genparams.get('max_length', 50),
-temperature=genparams.get('temperature', 0.8),
-top_k=int(genparams.get('top_k', 120)),
-top_a=genparams.get('top_a', 0.0),
-top_p=genparams.get('top_p', 0.85),
-typical_p=genparams.get('typical', 1.0),
-tfs=genparams.get('tfs', 1.0),
-rep_pen=genparams.get('rep_pen', 1.1),
-rep_pen_range=genparams.get('rep_pen_range', 128),
-seed=genparams.get('sampler_seed', -1),
-stop_sequence=genparams.get('stop_sequence', [])
-)
-utfprint("\nOutput: " + recvtxt)
-res = {"results": [{"text": recvtxt}]}
-else:
-recvtxt = generate(
-prompt=newprompt,
-max_length=genparams.get('max', 50),
-temperature=genparams.get('temperature', 0.8),
-top_k=int(genparams.get('top_k', 120)),
-top_a=genparams.get('top_a', 0.0),
-top_p=genparams.get('top_p', 0.85),
-typical_p=genparams.get('typical', 1.0),
-tfs=genparams.get('tfs', 1.0),
-rep_pen=genparams.get('rep_pen', 1.1),
-rep_pen_range=genparams.get('rep_pen_range', 128),
-seed=genparams.get('sampler_seed', -1),
-stop_sequence=genparams.get('stop_sequence', [])
-)
-utfprint("\nOutput: " + recvtxt)
-res = {"data": {"seqs":[recvtxt]}}
-try:
-self.send_response(200)
-self.end_headers()
-self.wfile.write(json.dumps(res).encode())
-except:
-print("Generate: The response could not be sent, maybe connection was terminated?")
 modelbusy = False
-return
+return web.Response()
-self.send_response(404)
-self.end_headers()
-def do_OPTIONS(self):
+return web.Response(status=404)
-self.send_response(200)
-self.end_headers()
-def do_HEAD(self):
+async def handle_options(self):
-self.send_response(200)
+return web.Response()
-self.end_headers()
-def end_headers(self):
+async def handle_head(self):
-self.send_header('Access-Control-Allow-Origin', '*')
+return web.Response()
-self.send_header('Access-Control-Allow-Methods', '*')
-self.send_header('Access-Control-Allow-Headers', '*')
-if "/api" in self.path:
-self.send_header('Content-type', 'application/json')
-else:
-self.send_header('Content-type', 'text/html')
-return super(ServerRequestHandler, self).end_headers()
+async def start_server(self):
+self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
+self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
+self.app.router.add_route('OPTIONS', '/', self.handle_options)
+self.app.router.add_route('HEAD', '/', self.handle_head)
-def RunServerMultiThreaded(addr, port, embedded_kailite = None):
+runner = web.AppRunner(self.app)
-sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+await runner.setup()
-sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+site = web.TCPSite(runner, self.addr, self.port)
-sock.bind((addr, port))
+await site.start()
-sock.listen(5)
-class Thread(threading.Thread):
+# Keep Alive
-def __init__(self, i):
-threading.Thread.__init__(self)
-self.i = i
-self.daemon = True
-self.start()
-def run(self):
-handler = ServerRequestHandler(addr, port, embedded_kailite)
-with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
-try:
-self.httpd.socket = sock
-self.httpd.server_bind = self.server_close = lambda self: None
-self.httpd.serve_forever()
-except (KeyboardInterrupt,SystemExit):
-self.httpd.server_close()
-sys.exit(0)
-finally:
-self.httpd.server_close()
-sys.exit(0)
-def stop(self):
-self.httpd.server_close()
-numThreads = 6
-threadArr = []
-for i in range(numThreads):
-threadArr.append(Thread(i))
-while 1:
 try:
-time.sleep(10)
+while True:
+await asyncio.sleep(3600)
 except KeyboardInterrupt:
-for i in range(numThreads):
+await runner.cleanup()
-threadArr[i].stop()
-sys.exit(0)
+async def run_server(addr, port, embedded_kailite=None):
+handler = ServerRequestHandler(addr, port, embedded_kailite)
+await handler.start_server()
def show_gui(): def show_gui():
@@ -500,15 +468,15 @@ def show_gui():
 unbantokens = tk.IntVar()
 highpriority = tk.IntVar()
 disablemmap = tk.IntVar()
-frameD = tk.Frame(root)
-tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
-tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
-tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
-tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
-tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
-tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
-frameD.grid(row=5,column=0,pady=4)
+frm3 = tk.Frame(root)
+tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
+tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
+tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
+tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
+tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
+tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
+frm3.grid(row=5,column=0,pady=4)
 # Create button, it will change label text
 tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0)
@@ -688,7 +656,7 @@ def main(args):
 except:
 print("--launch was set, but could not launch web browser automatically.")
 print(f"Please connect to custom endpoint at {epurl}")
-RunServerMultiThreaded(args.host, args.port, embedded_kailite)
+asyncio.run(run_server(args.host, args.port, embedded_kailite))
 if __name__ == '__main__':
 print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually