Use coversation template from fastchat for api proxy
Fix eventsource format
This commit is contained in:
parent
eb542d3932
commit
ea5a7fbc95
1 changed files with 45 additions and 20 deletions
|
@ -1,14 +1,18 @@
|
||||||
import argparse
|
import argparse
|
||||||
from flask import Flask, jsonify, request, Response
|
from flask import Flask, jsonify, request, Response
|
||||||
|
from flask_cors import CORS
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
from fastchat import conversation
|
||||||
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
CORS(app)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
|
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
|
||||||
|
parser.add_argument("--chat-prompt-model", type=str, help="Set the model", default="")
|
||||||
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
|
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
|
||||||
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
|
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
|
||||||
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
|
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
|
||||||
|
@ -29,16 +33,37 @@ def is_present(json, key):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
use_conversation_template = args.chat_prompt_model != ""
|
||||||
|
|
||||||
|
if use_conversation_template:
|
||||||
|
conv = conversation.get_conv_template(args.chat_prompt_model)
|
||||||
|
stop_token = conv.stop_str
|
||||||
|
else:
|
||||||
|
stop_token = args.stop
|
||||||
|
|
||||||
|
|
||||||
#convert chat to prompt
|
#convert chat to prompt
|
||||||
def convert_chat(messages):
|
def convert_chat(messages):
|
||||||
|
if use_conversation_template:
|
||||||
|
conv = conversation.get_conv_template(args.chat_prompt_model)
|
||||||
|
for line in messages:
|
||||||
|
if (line["role"] == "system"):
|
||||||
|
try:
|
||||||
|
conv.set_system_msg(line["content"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif (line["role"] == "user"):
|
||||||
|
conv.append_message(conv.roles[0], line["content"])
|
||||||
|
elif (line["role"] == "assistant"):
|
||||||
|
conv.append_message(conv.roles[1], line["content"])
|
||||||
|
conv.append_message(conv.roles[1], None)
|
||||||
|
prompt = conv.get_prompt()
|
||||||
|
else:
|
||||||
prompt = "" + args.chat_prompt.replace("\\n", "\n")
|
prompt = "" + args.chat_prompt.replace("\\n", "\n")
|
||||||
|
|
||||||
system_n = args.system_name.replace("\\n", "\n")
|
system_n = args.system_name.replace("\\n", "\n")
|
||||||
user_n = args.user_name.replace("\\n", "\n")
|
user_n = args.user_name.replace("\\n", "\n")
|
||||||
ai_n = args.ai_name.replace("\\n", "\n")
|
ai_n = args.ai_name.replace("\\n", "\n")
|
||||||
stop = args.stop.replace("\\n", "\n")
|
stop = stop_token.replace("\\n", "\n")
|
||||||
|
|
||||||
|
|
||||||
for line in messages:
|
for line in messages:
|
||||||
if (line["role"] == "system"):
|
if (line["role"] == "system"):
|
||||||
|
@ -69,8 +94,8 @@ def make_postData(body, chat=False, stream=False):
|
||||||
if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
|
if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
|
||||||
if(is_present(body, "seed")): postData["seed"] = body["seed"]
|
if(is_present(body, "seed")): postData["seed"] = body["seed"]
|
||||||
if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
|
if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
|
||||||
if (args.stop != ""):
|
if stop_token: # "" or None
|
||||||
postData["stop"] = [args.stop]
|
postData["stop"] = [stop_token]
|
||||||
else:
|
else:
|
||||||
postData["stop"] = []
|
postData["stop"] = []
|
||||||
if(is_present(body, "stop")): postData["stop"] += body["stop"]
|
if(is_present(body, "stop")): postData["stop"] += body["stop"]
|
||||||
|
@ -173,12 +198,12 @@ def chat_completions():
|
||||||
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
||||||
time_now = int(time.time())
|
time_now = int(time.time())
|
||||||
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
|
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
for line in data.iter_lines():
|
for line in data.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
|
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream')
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,7 +237,7 @@ def completion():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
|
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue