Merge remote-tracking branch 'elsagranger/master'

Laura 2023-08-10 07:59:55 +02:00
commit 2c8e92044e
2 changed files with 62 additions and 22 deletions

README.md

@@ -197,6 +197,14 @@ bash chat.sh
API example using Python Flask: [api_like_OAI.py](api_like_OAI.py)
This example must be used with server.cpp
Requirements:
```shell
pip install flask flask-cors fschat  # flask-cors and fschat are optional: flask-cors enables cross-origin requests, fschat adds chat-template support
```
Run the server:
```sh
python api_like_OAI.py
```
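For context, a typical end-to-end launch might look like the following. The model path, binary name, and flags are illustrative, not taken from this commit; `server` is the binary built from server.cpp:
```sh
# 1. start the llama.cpp server (binary name, model path and flags depend on your build)
./server -m models/llama-2-7b-chat.gguf

# 2. start the OpenAI-compatible proxy, optionally naming a FastChat template
python api_like_OAI.py --chat-prompt-model llama-2
```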
@@ -206,6 +214,8 @@ After running the API server, you can use it in Python by setting the API base URL
openai.api_base = "http://<Your api-server IP>:port"
```
For better integration with the model, it is recommended to use the `--chat-prompt-model` parameter when starting the server, rather than relying solely on parameters like `--user-name`. This parameter accepts a model name registered in [FastChat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py), for example `llama-2`.
Then you can use llama.cpp as an OpenAI-compatible **chat.completion** or **text_completion** API.
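As a quick sanity check, a minimal client sketch, assuming the pre-1.0 `openai` Python package and an illustrative host and port; adjust the base URL to your setup:
```python
import openai

openai.api_key = "sk-no-key-required"      # any string: the proxy does not check keys
openai.api_base = "http://127.0.0.1:8081"  # illustrative; point at api_like_OAI.py

resp = openai.ChatCompletion.create(
    model="llama-2",  # accepted but not used for routing by the proxy
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp["choices"][0]["message"]["content"])
```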
### Extending or building alternative Web Front End

api_like_OAI.py

@@ -4,11 +4,20 @@ import urllib.parse
import requests
import time
import json
# fschat is optional: without it we fall back to the legacy prompt format
try:
    from fastchat import conversation
except ImportError:
    conversation = None
app = Flask(__name__)
# flask-cors is optional: without it, cross-origin requests are simply not enabled
try:
    from flask_cors import CORS
    CORS(app)
except ImportError:
    pass
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt-model", type=str, help="Set the model name of conversation template", default="")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
@@ -29,25 +38,46 @@ def is_present(json, key):
    return True
# prefer the FastChat template's stop string when a conversation template is in use
use_conversation_template = args.chat_prompt_model != "" and conversation is not None
if use_conversation_template:
    conv = conversation.get_conv_template(args.chat_prompt_model)
    stop_token = conv.stop_str
else:
    stop_token = args.stop
# convert chat to prompt
def convert_chat(messages):
    if use_conversation_template:
        # build the prompt through the FastChat conversation template
        conv = conversation.get_conv_template(args.chat_prompt_model)
        for line in messages:
            if (line["role"] == "system"):
                try:
                    conv.set_system_message(line["content"])
                except Exception:
                    pass
            elif (line["role"] == "user"):
                conv.append_message(conv.roles[0], line["content"])
            elif (line["role"] == "assistant"):
                conv.append_message(conv.roles[1], line["content"])
        conv.append_message(conv.roles[1], None)  # empty assistant slot for the model to fill
        prompt = conv.get_prompt()
    else:
        # legacy path: stitch the prompt together from the --*-name arguments
        prompt = "" + args.chat_prompt.replace("\\n", "\n")
        system_n = args.system_name.replace("\\n", "\n")
        user_n = args.user_name.replace("\\n", "\n")
        ai_n = args.ai_name.replace("\\n", "\n")
        stop = stop_token.replace("\\n", "\n")

        for line in messages:
            if (line["role"] == "system"):
                prompt += f"{system_n}{line['content']}"
            if (line["role"] == "user"):
                prompt += f"{user_n}{line['content']}"
            if (line["role"] == "assistant"):
                prompt += f"{ai_n}{line['content']}{stop}"
        prompt += ai_n.rstrip()
    return prompt
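For reference, a standalone sketch of what the template path above produces, using only the fschat calls that already appear in this file; `llama-2` is one of the registered template names:
```python
from fastchat import conversation

conv = conversation.get_conv_template("llama-2")
conv.set_system_message("You are terse.")     # older fschat versions may lack this, hence the try/except above
conv.append_message(conv.roles[0], "Hello!")  # user turn
conv.append_message(conv.roles[1], None)      # empty slot for the model to fill
print(conv.get_prompt())  # the fully formatted chat prompt
print(conv.stop_str)      # the stop string used as stop_token above
```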
@@ -69,11 +99,11 @@ def make_postData(body, chat=False, stream=False):
if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"] if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
if(is_present(body, "seed")): postData["seed"] = body["seed"] if(is_present(body, "seed")): postData["seed"] = body["seed"]
if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
if (args.stop != ""): if stop_token: # "" or None
postData["stop"] = [args.stop] postData["stop"] = [stop_token]
else: else:
postData["stop"] = [] postData["stop"] = []
if(is_present(body, "stop")): postData["stop"] += body["stop"] if(is_present(body, "stop")): postData["stop"] += body["stop"] or []
postData["n_keep"] = -1 postData["n_keep"] = -1
postData["stream"] = stream postData["stream"] = stream
@@ -173,12 +203,12 @@ def chat_completions():
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
        time_now = int(time.time())
        resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
        yield 'data: {}\n\n'.format(json.dumps(resData))  # the blank line terminates the SSE event
        for line in data.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
                yield 'data: {}\n\n'.format(json.dumps(resData))
    return Response(generate(), mimetype='text/event-stream')
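The change from `'data: {}\n'` to `'data: {}\n\n'` is a Server-Sent Events fix: an SSE event is only dispatched once a blank line terminates it, so with a single newline clients buffer the whole stream instead of receiving chunks incrementally. A minimal consumer sketch with an illustrative URL:
```python
import json
import requests

# illustrative host/port; adjust to where api_like_OAI.py is listening
url = "http://127.0.0.1:8081/chat/completions"
payload = {"stream": True, "messages": [{"role": "user", "content": "Hi"}]}

with requests.post(url, json=payload, stream=True) as r:
    for line in r.iter_lines():         # yields b"" for the blank separator lines
        if line.startswith(b"data: "):  # keep only the event payload lines
            print(json.loads(line[6:]))  # one parsed chunk per SSE event
```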
@@ -212,7 +242,7 @@ def completion():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
                yield 'data: {}\n\n'.format(json.dumps(resData))
    return Response(generate(), mimetype='text/event-stream')

if __name__ == '__main__':