diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index ed19237b0..0ce383af6 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -52,12 +52,110 @@ def convert_chat(messages): return prompt +def make_grammar(schema, root): + def format_rulename(name): + return name.replace("_", "-") + + def schema_typename(schema, defs, arrs): + typename = '"null"' + if 'type' in schema: + typename = schema['type'] + if '$ref' in schema and schema['$ref'] in defs: + typename = defs[schema['$ref']] + if typename == 'array': + elemtype = schema_typename(schema['items'], defs, arrs) + typename = format_rulename(f'array-{elemtype}') + if typename not in arrs: + arrs[typename] = elemtype + return typename + + def arr_to_rule(name, elemtype): + rule = f'{name} ::= "[" ( {elemtype} ( "," {elemtype} )* )? "]"' + return rule + + def enum_to_rule(name, schema): + enum_values = schema['enum'] + etype = schema['type'] + def value_pattern(value): + if etype == 'string': + return f'"\\"{repr(value)[1:-1]}\\""' + else: + return repr(value) + values = ' | '.join([ + value_pattern(value) + for value in enum_values + ]) + rule = f'{name} ::= ( {values} )' + return rule + + def obj_to_rule(name, schema, defs, arrs): + assert(schema['type'] == 'object') + def propery_to_grammar(name, typename): + return f'"\\"" "{name}" "\\"" ":" {typename}' + properties = '"," '.join([ + propery_to_grammar(name, schema_typename(property, defs, arrs)) + for name, property in schema['properties'].items() + ]) + rule = f'{name} ::= "{{" {properties} "}}"' + return rule + + def model_grammar(schema, root = None): + rules = [] + defs = {} + arrs = {} + if '$defs' in schema: + for name, _def in schema['$defs'].items(): + defs['#/$defs/' + name] = format_rulename(name) + + for name, _def in schema['$defs'].items(): + if 'enum' in _def: + rules.append(enum_to_rule(format_rulename(name), _def)) + elif _def['type'] == 'object': + rules.append(obj_to_rule(format_rulename(name), _def, defs, arrs)) + + if root is None: + root = schema["title"] + root = format_rulename(root) + + if schema['type'] == 'object': + rules.append(obj_to_rule(root, schema, defs, arrs)) + + for arrtype, elemtype in arrs.items(): + rules.append(arr_to_rule(arrtype, elemtype)) + rules.append(f'root ::= {root}') + grammar = "\n".join(rules) + grammar += ( # json base types +(r''' +ws ::= [ \t\n]? +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" +bool ::= "True" | "False" +integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? +''')) + return grammar + return model_grammar(schema, root) + def make_postData(body, chat=False, stream=False): postData = {} if (chat): postData["prompt"] = convert_chat(body["messages"]) else: postData["prompt"] = body["prompt"] + + if(is_present(body, "function_call") and is_present(body["function_call"], "name")): + assert(is_present(body, "functions")) + functions = {} + for function in body["functions"]: + functions[function['name']] = function['parameters'] + function_call = body["function_call"]["name"] + postData["grammar"] = make_grammar(functions[function_call], function_call) + print("grammar") + print(postData["grammar"]) + if(is_present(body, "temperature")): postData["temperature"] = body["temperature"] if(is_present(body, "top_k")): postData["top_k"] = body["top_k"] if(is_present(body, "top_p")): postData["top_p"] = body["top_p"] @@ -80,7 +178,7 @@ def make_postData(body, chat=False, stream=False): return postData -def make_resData(data, chat=False, promptToken=[]): +def make_resData(data, chat=False, promptToken=[], function_call={}): resData = { "id": "chatcmpl" if (chat) else "cmpl", "object": "chat.completion" if (chat) else "text_completion", @@ -95,7 +193,21 @@ def make_resData(data, chat=False, promptToken=[]): } if (len(promptToken) != 0): resData["promptToken"] = promptToken - if (chat): + + if chat and is_present(requestBody, "function_call") and is_present(function_call, "name"): + resData["choices"][0]["delta"] = [{ + "index": 0, + "function_call": { + "name": function_call["name"], + "arguments": "" + }, + "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" + }] + if start: + resData["choices"][0]["delta"]["role"] = "assistant" + if is_present(data, "content"): + resData["choices"][0]["delta"]["function_call"]["arguments"] = data["content"] + elif (chat): #only one choice is supported resData["choices"] = [{ "index": 0, @@ -115,7 +227,7 @@ def make_resData(data, chat=False, promptToken=[]): }] return resData -def make_resData_stream(data, chat=False, time_now = 0, start=False): +def make_resData_stream(data, chat=False, time_now = 0, start=False, function_call={}): resData = { "id": "chatcmpl" if (chat) else "cmpl", "object": "chat.completion.chunk" if (chat) else "text_completion.chunk", @@ -129,7 +241,20 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): ] } if (chat): - if (start): + if is_present(function_call, "name"): + resData["choices"][0]["delta"] = { + "function_call": { + "name": function_call["name"], + "arguments" : "" + } + } + if start: + resData["choices"][0]["delta"]["role"] = "assistant" + if is_present(data, "content"): + resData["choices"][0]["delta"]["function_call"]["arguments"] = data["content"] + if is_present(data, "stop") and data["stop"]: + resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" + elif (start): resData["choices"][0]["delta"] = { "role": "assistant" } @@ -167,18 +292,18 @@ def chat_completions(): if (not stream): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) print(data.json()) - resData = make_resData(data.json(), chat=True, promptToken=promptToken) + resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=body.get("function_call")) return jsonify(resData) else: def generate(): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) time_now = int(time.time()) - resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) + resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=body.get("function_call")) yield 'data: {}\n'.format(json.dumps(resData)) for line in data.iter_lines(): if line: decoded_line = line.decode('utf-8') - resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) + resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=body.get("function_call")) yield 'data: {}\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream')