add basic support for function calls
When functions and function_call are specified in a chat completion request, a grammar is generated from the JSON schema given in functions[function_call] and used to constrain the completion.
parent f5ef5cfb18
commit e950411b8b
1 changed file with 132 additions and 7 deletions
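For illustration, a chat completion request that exercises the new path could look like the sketch below. The function name, its parameters schema, and the host/port/route are made up for this example; any function whose parameters field is a JSON schema of type object should behave the same way.

import json
import requests

# Hypothetical request body; the proxy builds a grammar from the
# "parameters" schema of the function named in "function_call".
request_body = {
    "messages": [
        {"role": "user", "content": "What is the weather in Berlin over the next 3 days?"}
    ],
    "functions": [
        {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "title": "get_weather",
                "properties": {
                    "location": {"type": "string"},
                    "days": {"type": "integer"}
                }
            }
        }
    ],
    # Naming a function here is what triggers the grammar-constrained output.
    "function_call": {"name": "get_weather"}
}

# Adjust the URL to wherever this api_like_OAI.py proxy exposes its
# chat completions route (host, port, and path are assumptions here).
response = requests.post("http://localhost:8081/chat/completions",
                         data=json.dumps(request_body))
print(response.json())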
@@ -52,12 +52,110 @@ def convert_chat(messages):
     return prompt
 
+def make_grammar(schema, root):
+    def format_rulename(name):
+        return name.replace("_", "-")
+
+    def schema_typename(schema, defs, arrs):
+        typename = '"null"'
+        if 'type' in schema:
+            typename = schema['type']
+        if '$ref' in schema and schema['$ref'] in defs:
+            typename = defs[schema['$ref']]
+        if typename == 'array':
+            elemtype = schema_typename(schema['items'], defs, arrs)
+            typename = format_rulename(f'array-{elemtype}')
+            if typename not in arrs:
+                arrs[typename] = elemtype
+        return typename
+
+    def arr_to_rule(name, elemtype):
+        rule = f'{name} ::= "[" ( {elemtype} ( "," {elemtype} )* )? "]"'
+        return rule
+
+    def enum_to_rule(name, schema):
+        enum_values = schema['enum']
+        etype = schema['type']
+        def value_pattern(value):
+            if etype == 'string':
+                return f'"\\"{repr(value)[1:-1]}\\""'
+            else:
+                return repr(value)
+        values = ' | '.join([
+            value_pattern(value)
+            for value in enum_values
+        ])
+        rule = f'{name} ::= ( {values} )'
+        return rule
+
+    def obj_to_rule(name, schema, defs, arrs):
+        assert(schema['type'] == 'object')
+        def propery_to_grammar(name, typename):
+            return f'"\\"" "{name}" "\\"" ":" {typename}'
+        properties = '"," '.join([
+            propery_to_grammar(name, schema_typename(property, defs, arrs))
+            for name, property in schema['properties'].items()
+        ])
+        rule = f'{name} ::= "{{" {properties} "}}"'
+        return rule
+
+    def model_grammar(schema, root = None):
+        rules = []
+        defs = {}
+        arrs = {}
+        if '$defs' in schema:
+            for name, _def in schema['$defs'].items():
+                defs['#/$defs/' + name] = format_rulename(name)
+
+            for name, _def in schema['$defs'].items():
+                if 'enum' in _def:
+                    rules.append(enum_to_rule(format_rulename(name), _def))
+                elif _def['type'] == 'object':
+                    rules.append(obj_to_rule(format_rulename(name), _def, defs, arrs))
+
+        if root is None:
+            root = schema["title"]
+        root = format_rulename(root)
+
+        if schema['type'] == 'object':
+            rules.append(obj_to_rule(root, schema, defs, arrs))
+
+        for arrtype, elemtype in arrs.items():
+            rules.append(arr_to_rule(arrtype, elemtype))
+        rules.append(f'root ::= {root}')
+        grammar = "\n".join(rules)
+        grammar += ( # json base types
+            (r'''
+ws ::= [ \t\n]?
+string ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\""
+bool ::= "True" | "False"
+integer ::= ("-"? ([0-9] | [1-9] [0-9]*))
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)?
+'''))
+        return grammar
+
+    return model_grammar(schema, root)
+
 def make_postData(body, chat=False, stream=False):
     postData = {}
     if (chat):
         postData["prompt"] = convert_chat(body["messages"])
     else:
         postData["prompt"] = body["prompt"]
+
+    if(is_present(body, "function_call") and is_present(body["function_call"], "name")):
+        assert(is_present(body, "functions"))
+        functions = {}
+        for function in body["functions"]:
+            functions[function['name']] = function['parameters']
+        function_call = body["function_call"]["name"]
+        postData["grammar"] = make_grammar(functions[function_call], function_call)
+        print("grammar")
+        print(postData["grammar"])
+
     if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
     if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
     if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
@@ -80,7 +178,7 @@ def make_postData(body, chat=False, stream=False):
     return postData
 
-def make_resData(data, chat=False, promptToken=[]):
+def make_resData(data, chat=False, promptToken=[], function_call={}):
     resData = {
         "id": "chatcmpl" if (chat) else "cmpl",
         "object": "chat.completion" if (chat) else "text_completion",
 
@@ -95,7 +193,21 @@ def make_resData(data, chat=False, promptToken=[]):
     }
     if (len(promptToken) != 0):
         resData["promptToken"] = promptToken
-    if (chat):
+
+    if chat and is_present(requestBody, "function_call") and is_present(function_call, "name"):
+        resData["choices"][0]["delta"] = [{
+            "index": 0,
+            "function_call": {
+                "name": function_call["name"],
+                "arguments": ""
+            },
+            "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+        }]
+        if start:
+            resData["choices"][0]["delta"]["role"] = "assistant"
+        if is_present(data, "content"):
+            resData["choices"][0]["delta"]["function_call"]["arguments"] = data["content"]
+    elif (chat):
         #only one choice is supported
         resData["choices"] = [{
             "index": 0,
@@ -115,7 +227,7 @@ def make_resData(data, chat=False, promptToken=[]):
     }]
     return resData
 
-def make_resData_stream(data, chat=False, time_now = 0, start=False):
+def make_resData_stream(data, chat=False, time_now = 0, start=False, function_call={}):
     resData = {
         "id": "chatcmpl" if (chat) else "cmpl",
         "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
@@ -129,7 +241,20 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
         ]
     }
     if (chat):
-        if (start):
+        if is_present(function_call, "name"):
+            resData["choices"][0]["delta"] = {
+                "function_call": {
+                    "name": function_call["name"],
+                    "arguments" : ""
+                }
+            }
+            if start:
+                resData["choices"][0]["delta"]["role"] = "assistant"
+            if is_present(data, "content"):
+                resData["choices"][0]["delta"]["function_call"]["arguments"] = data["content"]
+            if is_present(data, "stop") and data["stop"]:
+                resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+        elif (start):
             resData["choices"][0]["delta"] = {
                 "role": "assistant"
            }
@@ -167,18 +292,18 @@ def chat_completions():
     if (not stream):
         data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
         print(data.json())
-        resData = make_resData(data.json(), chat=True, promptToken=promptToken)
+        resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=body.get("function_call"))
         return jsonify(resData)
     else:
         def generate():
             data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
             time_now = int(time.time())
-            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
+            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=body.get("function_call"))
             yield 'data: {}\n'.format(json.dumps(resData))
             for line in data.iter_lines():
                 if line:
                     decoded_line = line.decode('utf-8')
-                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
+                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=body.get("function_call"))
                     yield 'data: {}\n'.format(json.dumps(resData))
         return Response(generate(), mimetype='text/event-stream')
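As a rough sketch of what make_grammar emits, feeding it the hypothetical get_weather schema from the request example above would produce a GBNF grammar along the lines shown in the comment below (approximate, not captured from a real run; the JSON base-type rules from the raw string literal are appended after it).

# Assumes the make_grammar helper added in this commit is in scope,
# and reuses the hypothetical get_weather schema from the earlier example.
schema = {
    "type": "object",
    "title": "get_weather",
    "properties": {
        "location": {"type": "string"},
        "days": {"type": "integer"}
    }
}

grammar = make_grammar(schema, "get_weather")
print(grammar)
# Approximate expected output, before the appended base-type rules:
#   get-weather ::= "{" "\"" "location" "\"" ":" string "," "\"" "days" "\"" ":" integer "}"
#   root ::= get-weather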