diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index aa325a03e..3f166609c 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -1,14 +1,18 @@ import argparse from flask import Flask, jsonify, request, Response +from flask_cors import CORS import urllib.parse import requests import time import json +from fastchat import conversation app = Flask(__name__) +CORS(app) parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") +parser.add_argument("--chat-prompt-model", type=str, help="Set the model", default="") parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") @@ -29,25 +33,46 @@ def is_present(json, key): return True +use_conversation_template = args.chat_prompt_model != "" + +if use_conversation_template: + conv = conversation.get_conv_template(args.chat_prompt_model) + stop_token = conv.stop_str +else: + stop_token = args.stop + #convert chat to prompt def convert_chat(messages): - prompt = "" + args.chat_prompt.replace("\\n", "\n") + if use_conversation_template: + conv = conversation.get_conv_template(args.chat_prompt_model) + for line in messages: + if (line["role"] == "system"): + try: + conv.set_system_msg(line["content"]) + except Exception: + pass + elif (line["role"] == "user"): + conv.append_message(conv.roles[0], line["content"]) + elif (line["role"] == "assistant"): + conv.append_message(conv.roles[1], line["content"]) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + else: + prompt = "" + args.chat_prompt.replace("\\n", "\n") + system_n = args.system_name.replace("\\n", "\n") + user_n = args.user_name.replace("\\n", "\n") + ai_n = args.ai_name.replace("\\n", "\n") + stop = stop_token.replace("\\n", "\n") - system_n = args.system_name.replace("\\n", "\n") - user_n = args.user_name.replace("\\n", "\n") - ai_n = args.ai_name.replace("\\n", "\n") - stop = args.stop.replace("\\n", "\n") - - - for line in messages: - if (line["role"] == "system"): - prompt += f"{system_n}{line['content']}" - if (line["role"] == "user"): - prompt += f"{user_n}{line['content']}" - if (line["role"] == "assistant"): - prompt += f"{ai_n}{line['content']}{stop}" - prompt += ai_n.rstrip() + for line in messages: + if (line["role"] == "system"): + prompt += f"{system_n}{line['content']}" + if (line["role"] == "user"): + prompt += f"{user_n}{line['content']}" + if (line["role"] == "assistant"): + prompt += f"{ai_n}{line['content']}{stop}" + prompt += ai_n.rstrip() return prompt @@ -69,8 +94,8 @@ def make_postData(body, chat=False, stream=False): if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"] if(is_present(body, "seed")): postData["seed"] = body["seed"] if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] - if (args.stop != ""): - postData["stop"] = [args.stop] + if stop_token: # "" or None + postData["stop"] = [stop_token] else: postData["stop"] = [] if(is_present(body, "stop")): postData["stop"] += body["stop"] @@ -173,12 +198,12 @@ def chat_completions(): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) time_now = int(time.time()) resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) for line in data.iter_lines(): if line: decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream') @@ -212,7 +237,7 @@ def completion(): if line: decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream') if __name__ == '__main__':