diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index 313e1a965..a67cb9f27 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -72,7 +72,7 @@ def make_postData(body, chat=False, stream=False): if(is_present(body, "seed")): postData["seed"] = body["seed"] if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] if (args.stop != ""): - postData["stop"] = [args.stop] + postData["stop"] = [args.stop, args.ai_name.replace("\\n", "\n"), args.user_name.replace("\\n", "\n")] else: postData["stop"] = [] if(is_present(body, "stop")): postData["stop"] += body["stop"] @@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): } ] } - slot_id = data["slot_id"] + slot_id = data.get("slot_id") if (chat): if (start): resData["choices"][0]["delta"] = { @@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): return resData -@app.route('/chat/completions', methods=['POST']) -@app.route('/v1/chat/completions', methods=['POST']) +@app.route('/chat/completions', methods=['POST', 'OPTIONS']) +@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS']) def chat_completions(): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): return Response(status=403) + if request.method == 'OPTIONS': + return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) body = request.get_json() stream = False tokenize = False @@ -183,14 +185,16 @@ def chat_completions(): decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) -@app.route('/completions', methods=['POST']) -@app.route('/v1/completions', methods=['POST']) +@app.route('/completions', methods=['POST', 'OPTIONS']) +@app.route('/v1/completions', methods=['POST', 'OPTIONS']) def completion(): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): return Response(status=403) + if request.method == 'OPTIONS': + return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) body = request.get_json() stream = False tokenize = False @@ -217,7 +221,7 @@ def completion(): decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) if __name__ == '__main__': app.run(args.host, port=args.port)