Implement basic chat/completions openai endpoint (#461)

* Implement basic chat/completions openai endpoint

-Basic support for the OpenAI chat/completions endpoint documented at: https://platform.openai.com/docs/api-reference/chat/create
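
A minimal sketch of calling the new endpoint with plain requests, assuming koboldcpp is running locally on its default port (5001); adjust the host/port to your setup:

    import requests

    resp = requests.post(
        "http://localhost:5001/v1/chat/completions",
        json={
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say hello."},
            ],
            "max_tokens": 80,
        },
    )
    # response shape matches the handler below: choices[0].message.content
    print(resp.json()["choices"][0]["message"]["content"])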

-Tested with OpenAI's example code for chat/completions, both with and without the stream=True parameter, found here: https://cookbook.openai.com/examples/how_to_stream_completions.
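
For reference, the stream=True path emits SSE lines of the form "data: {chat.completion.chunk json}", terminated by a literal [DONE] sentinel, so a consumer sketch (same local-port assumption as above) looks like:

    import json
    import requests

    with requests.post(
        "http://localhost:5001/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Count to five."}],
              "stream": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue  # skip blank separator lines between SSE events
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break  # end-of-stream sentinel, as in the openai format
            chunk = json.loads(payload)
            # each chunk carries an incremental delta, as built in the diff below
            print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)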

-Tested with Mantella, the Skyrim mod that turns all the NPCs into AI-chattable characters, which uses OpenAI's acreate / async completions method: https://github.com/art-from-the-machine/Mantella/blob/main/src/output_manager.py

-Tested default koboldcpp API behavior with the streaming and non-streaming generate endpoints and with the GUI running; everything seems fine.

-Still TODO / to evaluate before merging:

(1) implement the rest of the OpenAI chat/completions parameters to the extent possible, mapping them to koboldcpp parameters

(2) determine whether kobold's prompt formats for certain models can be used when translating the OpenAI messages format into a prompt string. (Not sure if this is possible or where these formats live in the code.)

(3) have chat/completions responses include the actual local model the user is running instead of just "koboldcpp". (Not sure if this is possible.)

Note: I am a Python noob, so if there is a more elegant way of doing this, at minimum I hope I have done some of the grunt work for you to implement on your own.

* Fix typographical error on deleted streaming argument

-Mistakenly left code relating to the streaming argument from the main branch in experimental.

* add additional openai chat completions parameters

-support the stop parameter, mapped to the koboldai stop_sequence parameter

-make the default max_length / max_tokens parameter consistent with the default 80-token length in the generate function

-add support for providing the name of the local model in openai responses

* Revert "add additional openai chat completions parameters"

This reverts commit 443a6f7ff6346f41c78b0a6ff59c063999542327.

* add additional openai chat completions parameters

-support the stop parameter, mapped to the koboldai stop_sequence parameter (normalization sketched after this list)

-make the default max_length / max_tokens parameter consistent with the default 80-token length in the generate function

-add support for providing the name of the local model in openai responses
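
The stop handling boils down to this normalization (a standalone sketch of the inline logic in the diff below; to_stop_sequence is a hypothetical helper name):

    def to_stop_sequence(stop):
        # openai accepts a single string or a list of strings for stop;
        # koboldcpp's generator expects a list in stop_sequence.
        if stop is None:
            return []
        return stop if isinstance(stop, list) else [stop]

    assert to_stop_sequence("###") == ["###"]
    assert to_stop_sequence(["###", "\n\n"]) == ["###", "\n\n"]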

* add "\n" after formatting prompts from openaiformat

to conform with the Alpaca standard used as the default in lite.koboldai.net
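
For example, the translation behaves like the sketch below (messages_to_prompt is a hypothetical name; the logic mirrors the api_format==4 branch in the diff):

    def messages_to_prompt(messages):
        # system and user turns become "### Instruction:" blocks, assistant
        # turns become "### Response:" blocks, and a trailing "### Response:\n"
        # cues the model to answer in the Alpaca style.
        prompt = ""
        for message in messages:
            if message["role"] in ("system", "user"):
                prompt += "\n### Instruction:\n"
            elif message["role"] == "assistant":
                prompt += "\n### Response:\n"
            prompt += message["content"]
        prompt += "\n### Response:\n"
        return prompt

    print(messages_to_prompt([{"role": "user", "content": "Hi there!"}]))
    # -> "\n### Instruction:\nHi there!\n### Response:\n"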

* tidy up and simplify code, do not set globals for streaming

* oai endpoints must start with v1
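
With everything rooted at /v1, the official client can be pointed at a local koboldcpp by overriding its base URL. A sketch using the pre-1.0 openai Python package (current when this was merged); the API key is required by the client but unused locally:

    import openai

    openai.api_base = "http://localhost:5001/v1"  # assumes the default local port
    openai.api_key = "sk-ignored"                 # not checked by koboldcpp

    reply = openai.ChatCompletion.create(
        model="koboldcpp",  # informational; the server runs whatever model was loaded
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(reply["choices"][0]["message"]["content"])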

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
teddybear082 authored on 2023-10-05 08:13:10 -04:00, committed by GitHub
commit f9f4cdf3c0 (parent 5beb773320)

@@ -391,17 +391,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         pass

     async def generate_text(self, genparams, api_format, stream_flag):
+        global friendlymodelname
         def run_blocking():
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"]=genparams.get('max', 50)
+                genparams["max_length"] = genparams.get('max', 80)
             elif api_format==3:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 50)
+                genparams["max_length"] = genparams.get('max_tokens', 80)
                 genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]
+            elif api_format==4:
+                # translate openai chat completion messages format into one big string.
+                messages_array = genparams.get('messages', [])
+                messages_string = ""
+                for message in messages_array:
+                    if message['role'] == "system":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "user":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "assistant":
+                        messages_string+="\n### Response:\n"
+                    messages_string+=message['content']
+                messages_string += "\n### Response:\n"
+                genparams["prompt"] = messages_string
+                frqp = genparams.get('frequency_penalty', 0.1)
+                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
+                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]
             return generate(
                 prompt=genparams.get('prompt', ""),
@@ -441,8 +469,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if api_format==1:
                 res = {"data": {"seqs":[recvtxt]}}
             elif api_format==3:
-                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
+                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
                 "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+            elif api_format==4:
+                res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
+                "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]}
             else:
                 res = {"results": [{"text": recvtxt}]}
@@ -452,19 +483,21 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             print(f"Generate: Error while generating: {e}")

-    async def send_sse_event(self, event, data):
-        self.wfile.write(f'event: {event}\n'.encode())
+    async def send_oai_sse_event(self, data):
+        self.wfile.write(f'data: {data}\r\n\r\n'.encode())

+    async def send_kai_sse_event(self, data):
+        self.wfile.write(f'event: message\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())

-    async def handle_sse_stream(self):
+    async def handle_sse_stream(self, api_format):
+        global friendlymodelname
         self.send_response(200)
         self.send_header("Cache-Control", "no-cache")
         self.send_header("Connection", "keep-alive")
-        self.end_headers()
+        self.end_headers(force_json=True, sse_stream_flag=True)
         current_token = 0
         incomplete_token_buffer = bytearray()
         while True:
             streamDone = handle.has_finished() #exit next loop on done
@@ -485,14 +518,20 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     tokenStr += tokenSeg
             if tokenStr!="":
-                event_data = {"token": tokenStr}
-                event_str = json.dumps(event_data)
+                if api_format == 4: # if oai chat, set format to expected openai streaming response
+                    event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                    await self.send_oai_sse_event(event_str)
+                else:
+                    event_str = json.dumps({"token": tokenStr})
+                    await self.send_kai_sse_event(event_str)
                 tokenStr = ""
-                await self.send_sse_event("message", event_str)
             else:
                 await asyncio.sleep(0.02) #this should keep things responsive

             if streamDone:
+                if api_format == 4: # if oai chat, send last [DONE] message consistent with openai format
+                    await self.send_oai_sse_event('[DONE]')
                 break

         # flush buffers, sleep a bit to make sure all data sent, and then force close the connection
@@ -505,7 +544,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         tasks = []

         if stream_flag:
-            tasks.append(self.handle_sse_stream())
+            tasks.append(self.handle_sse_stream(api_format))

         generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
         tasks.append(generate_task)
@@ -569,8 +608,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
             response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())

-        elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
-            response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+        elif self.path.endswith('/v1/models'):
+            response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
             force_json = True

         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
@@ -595,7 +634,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         body = self.rfile.read(content_length)
         self.path = self.path.rstrip('/')
-        force_json = False

         if self.path.endswith(('/api/extra/tokencount')):
             try:
                 genparams = json.loads(body)
@@ -658,9 +696,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0

         try:
-            kai_sse_stream_flag = False
+            sse_stream_flag = False
-            api_format = 0 #1=basic,2=kai,3=oai
+            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat

             if self.path.endswith('/request'):
                 api_format = 1
@@ -670,12 +708,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if self.path.endswith('/api/extra/generate/stream'):
                 api_format = 2
-                kai_sse_stream_flag = True
+                sse_stream_flag = True

-            if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
+            if self.path.endswith('/v1/completions'):
                 api_format = 3
                 force_json = True

+            if self.path.endswith('/v1/chat/completions'):
+                api_format = 4
+                force_json = True

             if api_format>0:
                 genparams = None
                 try:
@@ -690,11 +732,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if args.foreground:
                     bring_terminal_to_foreground()

-                gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
+                # Check if streaming chat completions, if so, set stream mode to true
+                if api_format == 4 and "stream" in genparams and genparams["stream"]:
+                    sse_stream_flag = True

+                gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))

                 try:
                     # Headers are already sent when streaming
-                    if not kai_sse_stream_flag:
+                    if not sse_stream_flag:
                         self.send_response(200)
                         self.end_headers(force_json=force_json)
                         self.wfile.write(json.dumps(gen).encode())
@@ -717,12 +763,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_response(200)
         self.end_headers()

-    def end_headers(self, force_json=False):
+    def end_headers(self, force_json=False, sse_stream_flag=False):
         self.send_header('Access-Control-Allow-Origin', '*')
         self.send_header('Access-Control-Allow-Methods', '*')
         self.send_header('Access-Control-Allow-Headers', '*')
         if "/api" in self.path or force_json:
-            if self.path.endswith("/stream"):
+            if sse_stream_flag:
                 self.send_header('Content-type', 'text/event-stream')
             self.send_header('Content-type', 'application/json')
         else:
@@ -919,10 +965,12 @@ def show_new_gui():
         x, y = root.winfo_pointerxy()
         tooltip.wm_geometry(f"+{x + 10}+{y + 10}")
         tooltip.deiconify()

     def hide_tooltip(event):
         if hasattr(show_tooltip, "_tooltip"):
             tooltip = show_tooltip._tooltip
             tooltip.withdraw()

     def setup_backend_tooltip(parent):
         num_backends_built = makelabel(parent, str(len(runopts)) + "/6", 5, 2)
         num_backends_built.grid(row=1, column=2, padx=0, pady=0)
@@ -1097,7 +1145,6 @@ def show_new_gui():
     for idx, name, in enumerate(token_boxes):
         makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)

     # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=2)