Merge branch 'concedo_experimental' of https://github.com/LostRuins/llamacpp-for-kobold into concedo_experimental
# Conflicts:
#	koboldcpp.py
commit a0c1ba7747
1 changed file with 74 additions and 27 deletions
koboldcpp.py (99)
@@ -392,17 +392,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 pass

 async def generate_text(self, genparams, api_format, stream_flag):
+global friendlymodelname
 def run_blocking():
 if api_format==1:
 genparams["prompt"] = genparams.get('text', "")
 genparams["top_k"] = int(genparams.get('top_k', 120))
-genparams["max_length"]=genparams.get('max', 50)
+genparams["max_length"] = genparams.get('max', 80)
 elif api_format==3:
 frqp = genparams.get('frequency_penalty', 0.1)
 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-genparams["max_length"] = genparams.get('max_tokens', 50)
+genparams["max_length"] = genparams.get('max_tokens', 80)
 genparams["rep_pen"] = scaled_rep_pen
+# openai allows either a string or a list as a stop sequence
+if isinstance(genparams.get('stop',[]), list):
+genparams["stop_sequence"] = genparams.get('stop', [])
+else:
+genparams["stop_sequence"] = [genparams.get('stop')]
+elif api_format==4:
+# translate openai chat completion messages format into one big string.
+messages_array = genparams.get('messages', [])
+messages_string = ""
+for message in messages_array:
+if message['role'] == "system":
+messages_string+="\n### Instruction:\n"
+elif message['role'] == "user":
+messages_string+="\n### Instruction:\n"
+elif message['role'] == "assistant":
+messages_string+="\n### Response:\n"
+messages_string+=message['content']
+messages_string += "\n### Response:\n"
+genparams["prompt"] = messages_string
+frqp = genparams.get('frequency_penalty', 0.1)
+scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
+genparams["max_length"] = genparams.get('max_tokens', 80)
+genparams["rep_pen"] = scaled_rep_pen
+# openai allows either a string or a list as a stop sequence
+if isinstance(genparams.get('stop',[]), list):
+genparams["stop_sequence"] = genparams.get('stop', [])
+else:
+genparams["stop_sequence"] = [genparams.get('stop')]

 return generate(
 prompt=genparams.get('prompt', ""),
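For context on the new api_format==4 branch: an OpenAI chat request is flattened into a single Alpaca-style prompt (system and user turns become "### Instruction", assistant turns become "### Response"), and the OpenAI penalty fields are folded into rep_pen. A minimal standalone sketch of that translation follows; the helper name translate_oai_chat and the sample payload are illustrative only, the real logic runs inline inside generate_text:

# Illustrative sketch of the api_format==4 translation performed above.
def translate_oai_chat(genparams):
    messages_string = ""
    for message in genparams.get('messages', []):
        if message['role'] == "system":
            messages_string += "\n### Instruction:\n"
        elif message['role'] == "user":
            messages_string += "\n### Instruction:\n"
        elif message['role'] == "assistant":
            messages_string += "\n### Response:\n"
        messages_string += message['content']
    messages_string += "\n### Response:\n"

    genparams["prompt"] = messages_string
    frqp = genparams.get('frequency_penalty', 0.1)
    genparams["rep_pen"] = genparams.get('presence_penalty', frqp) + 1
    genparams["max_length"] = genparams.get('max_tokens', 80)

    # OpenAI allows either a string or a list as a stop sequence.
    stop = genparams.get('stop', [])
    genparams["stop_sequence"] = stop if isinstance(stop, list) else [stop]
    return genparams

if __name__ == "__main__":
    sample = {"messages": [{"role": "user", "content": "Hello"}],
              "max_tokens": 32, "presence_penalty": 0.2, "stop": "###"}
    print(translate_oai_chat(sample)["prompt"])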
@@ -442,8 +470,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 if api_format==1:
 res = {"data": {"seqs":[recvtxt]}}
 elif api_format==3:
-res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
+res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
 "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+elif api_format==4:
+res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
+"choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]}
 else:
 res = {"results": [{"text": recvtxt}]}

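The non-streaming responses above now come in two OpenAI-compatible shapes: api_format 3 returns a text_completion object with the text under choices[0].text, while api_format 4 returns a chat.completion object with the text under choices[0].message.content. A small hedged sketch of how a client would read either shape (the payloads below are illustrative):

# Illustrative payloads mirroring the api_format 3 and 4 responses built above.
completion_res = {"object": "text_completion",
                  "choices": [{"text": "Hi there.", "index": 0, "finish_reason": "length"}]}
chat_res = {"object": "chat.completion",
            "choices": [{"index": 0,
                         "message": {"role": "assistant", "content": "Hi there."},
                         "finish_reason": "length"}]}

def extract_text(res):
    choice = res["choices"][0]
    if res["object"] == "chat.completion":
        return choice["message"]["content"]
    return choice["text"]

print(extract_text(completion_res))  # -> Hi there.
print(extract_text(chat_res))        # -> Hi there.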
@@ -453,19 +484,21 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 print(f"Generate: Error while generating: {e}")


-async def send_sse_event(self, event, data):
-self.wfile.write(f'event: {event}\n'.encode())
+async def send_oai_sse_event(self, data):
+self.wfile.write(f'data: {data}\r\n\r\n'.encode())

+async def send_kai_sse_event(self, data):
+self.wfile.write(f'event: message\n'.encode())
 self.wfile.write(f'data: {data}\n\n'.encode())

-async def handle_sse_stream(self):
+async def handle_sse_stream(self, api_format):
+global friendlymodelname
 self.send_response(200)
 self.send_header("Cache-Control", "no-cache")
 self.send_header("Connection", "keep-alive")
-self.end_headers()
+self.end_headers(force_json=True, sse_stream_flag=True)

 current_token = 0

 incomplete_token_buffer = bytearray()
 while True:
 streamDone = handle.has_finished() #exit next loop on done
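send_sse_event is split into two helpers above because the two APIs frame server-sent events differently: the OpenAI-style stream writes a bare data: line terminated with \r\n\r\n, while the Kobold stream keeps an event: message line and a \n\n terminator. A small sketch of both framings, writing to an in-memory buffer instead of self.wfile (payload contents are illustrative):

import io, json

buf = io.BytesIO()

def send_oai_sse_event(data):
    # OpenAI-style framing: bare data line, CRLF CRLF terminator
    buf.write(f'data: {data}\r\n\r\n'.encode())

def send_kai_sse_event(data):
    # Kobold-style framing: named event plus data line, LF LF terminator
    buf.write('event: message\n'.encode())
    buf.write(f'data: {data}\n\n'.encode())

send_oai_sse_event(json.dumps({"choices": [{"delta": {"content": "Hi"}}]}))
send_kai_sse_event(json.dumps({"token": "Hi"}))
print(buf.getvalue().decode())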
@@ -486,14 +519,20 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 tokenStr += tokenSeg

 if tokenStr!="":
-event_data = {"token": tokenStr}
-event_str = json.dumps(event_data)
+if api_format == 4: # if oai chat, set format to expected openai streaming response
+event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+await self.send_oai_sse_event(event_str)
+else:
+event_str = json.dumps({"token": tokenStr})
+await self.send_kai_sse_event(event_str)
 tokenStr = ""
-await self.send_sse_event("message", event_str)
 else:
 await asyncio.sleep(0.02) #this should keep things responsive

 if streamDone:
+if api_format == 4: # if oai chat, send last [DONE] message consistent with openai format
+await self.send_oai_sse_event('[DONE]')
 break

 # flush buffers, sleep a bit to make sure all data sent, and then force close the connection
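In streaming mode, api_format 4 wraps every decoded token in a chat.completion.chunk object and finishes with a literal [DONE] data line, matching the OpenAI streaming convention. A hedged client-side sketch of reassembling text from such events (the raw_events list stands in for a live stream):

import json

raw_events = [
    'data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hel"}}]}',
    'data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}',
    'data: [DONE]',
]

text = ""
for line in raw_events:
    payload = line[len("data: "):]
    if payload == "[DONE]":   # end-of-stream marker sent once streamDone is reached
        break
    chunk = json.loads(payload)
    text += chunk["choices"][0]["delta"].get("content", "")

print(text)  # -> Hello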
@@ -506,7 +545,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 tasks = []

 if stream_flag:
-tasks.append(self.handle_sse_stream())
+tasks.append(self.handle_sse_stream(api_format))

 generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
 tasks.append(generate_task)
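handle_request schedules the SSE writer and the text generation side by side so tokens can be flushed while generation is still running. A generic sketch of that pattern under asyncio; all names here are illustrative stand-ins, not the real coroutines:

import asyncio

async def fake_stream():                 # stands in for handle_sse_stream
    for _ in range(3):
        await asyncio.sleep(0.02)        # mirrors the polling sleep above

async def fake_generate():               # stands in for generate_text
    await asyncio.sleep(0.05)
    return "generated text"

async def handle_request_sketch(stream_flag):
    tasks = []
    if stream_flag:
        tasks.append(fake_stream())
    generate_task = asyncio.create_task(fake_generate())
    tasks.append(generate_task)
    await asyncio.gather(*tasks)
    return generate_task.result()

print(asyncio.run(handle_request_sketch(stream_flag=True)))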
@@ -570,8 +609,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
 response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())

-elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
-response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+elif self.path.endswith('/v1/models'):
+response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
 force_json = True

 elif self.path=="/api":
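The /v1/models listing now reports the loaded model's friendly name instead of a hard-coded "koboldcpp" id, and the bare "/models" alias is dropped. A hedged usage example; the host and port below are an assumption about the local setup, not part of the patch:

import json, urllib.request

# Host/port are an assumption; adjust to wherever the server is listening.
with urllib.request.urlopen("http://localhost:5001/v1/models") as resp:
    models = json.loads(resp.read())

for entry in models["data"]:
    print(entry["id"])   # the friendly model name reported by the server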
@@ -604,7 +643,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 body = self.rfile.read(content_length)
 self.path = self.path.rstrip('/')
 force_json = False
-
 if self.path.endswith(('/api/extra/tokencount')):
 try:
 genparams = json.loads(body)
@@ -667,9 +705,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0

 try:
-kai_sse_stream_flag = False
+sse_stream_flag = False

-api_format = 0 #1=basic,2=kai,3=oai
+api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat

 if self.path.endswith('/request'):
 api_format = 1
@@ -679,12 +717,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

 if self.path.endswith('/api/extra/generate/stream'):
 api_format = 2
-kai_sse_stream_flag = True
+sse_stream_flag = True

-if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
+if self.path.endswith('/v1/completions'):
 api_format = 3
 force_json = True

+if self.path.endswith('/v1/chat/completions'):
+api_format = 4
+force_json = True
+
 if api_format>0:
 genparams = None
 try:
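With the chat endpoint added, the POST routing above distinguishes four request formats. A compact illustrative summary; the real code checks path suffixes one by one rather than using a table:

# Illustrative mapping of endpoint suffix to api_format, mirroring the checks above.
API_FORMATS = {
    "/request":                   1,  # basic
    "/api/extra/generate/stream": 2,  # kai streaming (also sets sse_stream_flag)
    "/v1/completions":            3,  # oai text completions
    "/v1/chat/completions":       4,  # oai chat completions
}

def classify(path):
    for suffix, fmt in API_FORMATS.items():
        if path.endswith(suffix):
            return fmt
    return 0  # not a generation endpoint

print(classify("/v1/chat/completions"))  # -> 4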
@@ -699,11 +741,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 if args.foreground:
 bring_terminal_to_foreground()

-gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
+# Check if streaming chat completions, if so, set stream mode to true
+if api_format == 4 and "stream" in genparams and genparams["stream"]:
+sse_stream_flag = True
+
+gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))

 try:
 # Headers are already sent when streaming
-if not kai_sse_stream_flag:
+if not sse_stream_flag:
 self.send_response(200)
 self.end_headers(force_json=force_json)
 self.wfile.write(json.dumps(gen).encode())
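For chat completions, streaming is switched on through the request body's stream field rather than a separate streaming endpoint, which is what the new check above implements. A hedged example of a request body that would set sse_stream_flag (field values are illustrative):

import json

request_body = {
    "model": "koboldcpp",
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "max_tokens": 64,
    "stream": True,   # picked up by the api_format == 4 check above
}

print(json.dumps(request_body, indent=2))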
@@ -726,12 +772,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 self.send_response(200)
 self.end_headers()

-def end_headers(self, force_json=False):
+def end_headers(self, force_json=False, sse_stream_flag=False):
 self.send_header('Access-Control-Allow-Origin', '*')
 self.send_header('Access-Control-Allow-Methods', '*')
 self.send_header('Access-Control-Allow-Headers', '*')
 if ("/api" in self.path and self.path!="/api") or force_json:
-if self.path.endswith("/stream"):
+if sse_stream_flag:
 self.send_header('Content-type', 'text/event-stream')
 self.send_header('Content-type', 'application/json')
 else:
@@ -928,10 +974,12 @@ def show_new_gui():
 x, y = root.winfo_pointerxy()
 tooltip.wm_geometry(f"+{x + 10}+{y + 10}")
 tooltip.deiconify()

 def hide_tooltip(event):
 if hasattr(show_tooltip, "_tooltip"):
 tooltip = show_tooltip._tooltip
 tooltip.withdraw()

 def setup_backend_tooltip(parent):
 num_backends_built = makelabel(parent, str(len(runopts)) + "/6", 5, 2)
 num_backends_built.grid(row=1, column=2, padx=0, pady=0)
@@ -1106,7 +1154,6 @@ def show_new_gui():
 for idx, name, in enumerate(token_boxes):
 makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)

-
 # context size
 makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=2)
