Add support for multiple functions, a decision step that chooses between sending a message and making a function_call, and include previously sent function results in the chat prompt
This commit is contained in:
parent
b6ff08a291
commit
6e81bc5f8b
1 changed file with 165 additions and 50 deletions
|
@ -14,6 +14,7 @@ parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat comp
|
||||||
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
|
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
|
||||||
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
|
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
|
||||||
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
|
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
|
||||||
|
parser.add_argument("--function-name", type=str, help="FUNCTION name in chat completions(default: '\\nFUNCTION: ')", default="\\nFUNCTION: ")
|
||||||
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
|
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
|
||||||
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
|
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
|
||||||
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
|
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
|
||||||
|
@ -25,6 +26,8 @@ args = parser.parse_args()
|
||||||
def is_present(json, key):
    """Return True when ``json[key]`` can be looked up, False otherwise.

    Args:
        json: The container to probe, typically a decoded request body.
            Values that do not support subscripting with ``key`` (for
            example ``None``, or a plain string probed with a string key)
            are treated as "key not present".
        key: The key to look up.

    Returns:
        bool: True if the subscript succeeds, False if it raises
        ``KeyError`` or ``TypeError``.
    """
    try:
        json[key]
    except (TypeError, KeyError):
        # TypeError: the value is not subscriptable with this key at all.
        # KeyError: it is a mapping, but the key is missing.
        return False
    return True
|
||||||
|
@ -38,6 +41,7 @@ def convert_chat(messages):
|
||||||
system_n = args.system_name.replace("\\n", "\n")
|
system_n = args.system_name.replace("\\n", "\n")
|
||||||
user_n = args.user_name.replace("\\n", "\n")
|
user_n = args.user_name.replace("\\n", "\n")
|
||||||
ai_n = args.ai_name.replace("\\n", "\n")
|
ai_n = args.ai_name.replace("\\n", "\n")
|
||||||
|
fn_n = args.function_name.replace("\\n", "\n")
|
||||||
stop = args.stop.replace("\\n", "\n")
|
stop = args.stop.replace("\\n", "\n")
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,34 +50,52 @@ def convert_chat(messages):
|
||||||
prompt += f"{system_n}{line['content']}"
|
prompt += f"{system_n}{line['content']}"
|
||||||
if (line["role"] == "user"):
|
if (line["role"] == "user"):
|
||||||
prompt += f"{user_n}{line['content']}"
|
prompt += f"{user_n}{line['content']}"
|
||||||
|
if (line["role"] == "function"):
|
||||||
|
prompt += f"{fn_n}{line['content']}"
|
||||||
if (line["role"] == "assistant"):
|
if (line["role"] == "assistant"):
|
||||||
|
# if (not is_present(line, 'content') or line['content'] is None):
|
||||||
|
# if is_present(line, 'function_call'):
|
||||||
|
# # line['content'] = json.dumps({'function_call': line['function_call']})
|
||||||
|
# fname = line['function_call']['name']
|
||||||
|
# line['content'] = f'call {fname}'
|
||||||
|
|
||||||
prompt += f"{ai_n}{line['content']}{stop}"
|
prompt += f"{ai_n}{line['content']}{stop}"
|
||||||
prompt += ai_n.rstrip()
|
prompt += ai_n.rstrip()
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def make_grammar(schema, root):
|
def make_grammar(schema, root):
|
||||||
|
indent_inc = " "
|
||||||
def format_rulename(name):
    """Return *name* with every underscore replaced by a hyphen.

    Used to turn schema identifiers into the names under which grammar
    rules are stored.
    """
    return name.replace("_", "-")
|
||||||
|
|
||||||
def is_basetype(typename):
    """Return True when *typename* is a primitive type that is never prefixed.

    Primitive type names are used verbatim as rule names, whereas every
    other type name gets the per-function prefix in ``schema_typename``.
    """
    # JSON Schema spells the type "boolean"; the historical "bool" spelling
    # is still accepted so existing callers keep working.
    # NOTE(review): presumably the base grammar defines a rule for each of
    # these names -- confirm a "boolean" rule exists there.
    return typename in ('integer', 'number', 'bool', 'boolean', 'string')


def schema_typename(prefix, schema, defs, arrs):
    """Resolve the grammar rule name that describes *schema*.

    Args:
        prefix: Per-function prefix prepended to non-primitive type names.
        schema: A JSON-schema fragment (a dict); may carry ``type``,
            ``$ref`` and, for arrays, ``items``.
        defs: Mapping from ``$ref`` strings to already-assigned rule names.
        arrs: Mapping from array rule name to element rule name. It is
            updated in place whenever a new array type is encountered.

    Returns:
        str: The rule name, or the literal ``'"null"'`` when the schema
        carries neither a usable ``type`` nor a known ``$ref``.
    """
    typename = '"null"'
    if 'type' in schema:
        typename = schema['type']
        if not is_basetype(typename) and typename != 'array':
            typename = prefix + typename
    # A known $ref takes precedence over whatever ``type`` said.
    if '$ref' in schema and schema['$ref'] in defs:
        typename = defs[schema['$ref']]
    if typename == 'array':
        elemtype = schema_typename(prefix, schema['items'], defs, arrs)
        typename = f'array-{elemtype}'
        if not is_basetype(elemtype):
            typename = prefix + typename
        typename = format_rulename(typename)
        # Remember the array so a rule for it can be emitted later.
        if typename not in arrs:
            arrs[typename] = elemtype
    return typename
|
||||||
|
|
||||||
def arr_to_rules(rules, prefix, name, elemtype):
    """Store the grammar rule for an array whose elements are *elemtype*.

    Args:
        rules: Mapping of rule name to rule body; updated in place.
        prefix: Unused here; accepted so the signature matches the other
            ``*_to_rules`` helpers.
        name: Name of the array rule (underscores become hyphens).
        elemtype: Rule name of the array's element type.
    """
    rulename = format_rulename(name)
    # A bracketed, comma-separated, possibly empty list of elements.
    rules[rulename] = f'"[" ( {elemtype} ( "," {elemtype} )* )? "]"'
|
||||||
|
|
||||||
def enum_to_rule(name, schema):
|
def enum_to_rules(rules, prefix, name, schema):
|
||||||
enum_values = schema['enum']
|
enum_values = schema['enum']
|
||||||
etype = schema['type']
|
etype = schema['type']
|
||||||
def value_pattern(value):
|
def value_pattern(value):
|
||||||
|
@ -85,44 +107,106 @@ def make_grammar(schema, root):
|
||||||
value_pattern(value)
|
value_pattern(value)
|
||||||
for value in enum_values
|
for value in enum_values
|
||||||
])
|
])
|
||||||
rule = f'{name} ::= ( {values} )'
|
rulename = format_rulename(f'{prefix}{name}')
|
||||||
return rule
|
rules[rulename] = f'( {values} )'
|
||||||
|
|
||||||
def anyof_to_rules(rules, prefix, name, schema, defs, arrs):
    """Store the grammar rule for an ``anyOf`` schema as an alternation.

    Args:
        rules: Mapping of rule name to rule body; updated in place.
        prefix: Per-function prefix used for the rule name and passed on
            to ``schema_typename``.
        name: Schema-level name of the value being described.
        schema: Schema fragment containing an ``anyOf`` list.
        defs: Mapping from ``$ref`` strings to rule names.
        arrs: Mapping of array rule names to element rule names; may be
            extended by ``schema_typename``.
    """
    alternatives = ' | '.join([
        schema_typename(prefix, option, defs, arrs)
        for option in schema['anyOf']
    ])
    rulename = format_rulename(f'{prefix}{name}')
    rules[rulename] = f'( {alternatives} )'
|
||||||
|
|
||||||
|
def declare_rules(indent, rules, prefix, name, schema, defs, arrs):
    """Declare the grammar rules needed by one named schema fragment.

    Dispatches on the shape of *schema*: enums, ``anyOf`` unions and
    objects each get their own rule(s). Any other schema declares nothing.

    Args:
        indent: Current indentation string; only passed on (extended by
            ``indent_inc``) to ``obj_to_rules``.
        rules: Mapping of rule name to rule body; updated in place.
        prefix: Per-function prefix for generated rule names.
        name: Schema-level name of the fragment.
        schema: The schema fragment itself.
        defs: Mapping from ``$ref`` strings to rule names.
        arrs: Mapping of array rule names to element rule names.
    """
    rulename = format_rulename(name)
    if 'enum' in schema:
        enum_to_rules(rules, prefix, rulename, schema)
    elif 'anyOf' in schema:
        anyof_to_rules(rules, prefix, rulename, schema, defs, arrs)
    elif schema.get('type', None) == 'object':
        obj_to_rules(indent + indent_inc, rules, prefix, rulename, schema, defs, arrs, is_toplevel=False)
|
||||||
|
|
||||||
|
def obj_to_rules(indent, rules, prefix, name, schema, defs, arrs, is_toplevel):
    """Store the grammar rules for an object schema and return its rule name.

    Args:
        indent: Current indentation string; only passed on to
            ``declare_rules``.
        rules: Mapping of rule name to rule body; updated in place.
        prefix: Per-function prefix for generated rule names.
        name: Name of the object (the function name at top level).
        schema: Object schema; may carry ``$defs`` and ``properties``.
        defs: Mapping from ``$ref`` strings to rule names, or None.
        arrs: Mapping of array rule names to element rule names, or None.
        is_toplevel: When True the rule name is *name* without the prefix.

    Returns:
        str: The name under which the object's rule was stored.

    Raises:
        ValueError: If *schema* is not of type ``object``.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if schema.get('type') != 'object':
        raise ValueError(f"expected an object schema for {name!r}")
    if defs is None:
        defs = {}
    if arrs is None:
        arrs = {}

    rulename = name if is_toplevel else prefix + name
    rulename = format_rulename(rulename)

    definitions = schema.get('$defs', {})
    # A function without parameters may omit "properties" entirely.
    properties = schema.get('properties', {})

    # First pass: register every $ref target so definitions can refer to
    # each other regardless of declaration order.
    for def_name in definitions:
        defs['#/$defs/' + def_name] = format_rulename(prefix + def_name)

    # Second pass: declare the rules for the definitions themselves.
    for def_name, definition in definitions.items():
        declare_rules(indent + indent_inc, rules, prefix, def_name, definition, defs, arrs)

    for prop_name, prop in properties.items():
        declare_rules(indent + indent_inc, rules, prefix, prop_name, prop, defs, arrs)

    def property_to_grammar(prop_name, typename):
        # Quoted key, a colon, then the rule describing the value.
        return f'"\\"" "{prop_name}" "\\"" ":" {typename}'

    members = ' "," '.join([
        property_to_grammar(prop_name, schema_typename(prefix, prop, defs, arrs))
        for prop_name, prop in properties.items()
    ])
    rules[rulename] = f'"{{" {members} "}}"'

    # Emit the rules for every array type seen while resolving properties.
    for arrtype, elemtype in arrs.items():
        arr_to_rules(rules, prefix, arrtype, elemtype)

    return rulename
|
||||||
|
|
||||||
|
def model_grammar(schema, root = None):
|
||||||
|
indent = ""
|
||||||
|
rules = {}
|
||||||
|
defs = {}
|
||||||
|
fns = {}
|
||||||
|
arrs = {}
|
||||||
|
|
||||||
|
# print(indent, "model_grammar() schema")
|
||||||
|
# print(indent, schema)
|
||||||
|
# print(indent, "model_grammar() root", root)
|
||||||
|
|
||||||
|
for fn in schema:
|
||||||
|
name = fn['name']
|
||||||
|
params = fn['parameters']
|
||||||
|
prefix = f"{name}-"
|
||||||
|
fns[name] = obj_to_rules(indent + indent_inc, rules, prefix, name, params, {}, {}, is_toplevel=True)
|
||||||
|
|
||||||
|
# print(indent, "fns")
|
||||||
|
# print(indent, fns)
|
||||||
|
|
||||||
|
# assert("name" in root)
|
||||||
|
# assert(root["name"] in fns)
|
||||||
|
root = format_rulename(fns[root["name"]])
|
||||||
|
|
||||||
|
rules['root'] = root
|
||||||
|
for k in rules:
|
||||||
|
if callable(rules[k]):
|
||||||
|
rules[k] = rules[k]()
|
||||||
|
|
||||||
|
rules = [f'{k} ::= {v}' for k,v in rules.items()]
|
||||||
grammar = "\n".join(rules)
|
grammar = "\n".join(rules)
|
||||||
grammar += ( # json base types
|
grammar += ( # json base types
|
||||||
(r'''
|
(r'''
|
||||||
|
@ -137,24 +221,41 @@ integer ::= ("-"? ([0-9] | [1-9] [0-9]*))
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)?
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)?
|
||||||
'''))
|
'''))
|
||||||
return grammar
|
return grammar
|
||||||
return model_grammar(schema, root)
|
|
||||||
|
|
||||||
def make_postData(body, chat=False, stream=False):
|
def decision_grammar(schema, root):
    """Build the grammar for the "send a message or call a function" decision.

    The grammar accepts exactly one of two JSON objects:
    ``{"decision":"message"}`` or
    ``{"decision":"function","function_name":"<name>"}``, where ``<name>``
    is the name of one of the functions in *schema*.

    Args:
        schema: Iterable of function descriptions; only each entry's
            ``name`` is read.
        root: Unused; accepted so the signature matches ``model_grammar``.

    Returns:
        str: The grammar text, one ``rule ::= body`` per line.
    """
    def name_literal(fn_name):
        # JSON-encode the name so the generated text is a valid JSON string,
        # then escape backslashes and quotes so it is also a valid grammar
        # string literal. Plain names pass through unchanged.
        encoded = json.dumps(str(fn_name), ensure_ascii=False)[1:-1]
        encoded = encoded.replace('\\', '\\\\').replace('"', '\\"')
        return f'"\\"" "{encoded}" "\\""'

    rules = {
        "root": 'msg | call',
        "msg": '"{" "\\"" "decision" "\\"" ":" "\\"" "message" "\\"" "}"',
        "call": '"{" "\\"" "decision" "\\"" ":" "\\"" "function" "\\"" "," "\\"" "function_name" "\\"" ":" fname "}"',
        "fname": " | ".join(name_literal(fn['name']) for fn in schema),
    }
    return "\n".join(f'{k} ::= {v}' for k, v in rules.items())
|
||||||
|
|
||||||
|
if root == "auto":
|
||||||
|
return decision_grammar(schema, root)
|
||||||
|
if root is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return model_grammar(schema, root)
|
||||||
|
|
||||||
|
def make_postData(body, chat=False, stream=False, decide_function=False, function_call=None):
|
||||||
postData = {}
|
postData = {}
|
||||||
if (chat):
|
if (chat):
|
||||||
postData["prompt"] = convert_chat(body["messages"])
|
postData["prompt"] = convert_chat(body["messages"])
|
||||||
else:
|
else:
|
||||||
postData["prompt"] = body["prompt"]
|
postData["prompt"] = body["prompt"]
|
||||||
|
|
||||||
if(is_present(body, "function_call") and is_present(body["function_call"], "name")):
|
if(is_present(body, "functions") and len(body["functions"])>0):
|
||||||
assert(is_present(body, "functions"))
|
assert(is_present(body, "functions"))
|
||||||
functions = {}
|
grammar = make_grammar(body["functions"], function_call)
|
||||||
for function in body["functions"]:
|
if grammar is not None:
|
||||||
functions[function['name']] = function['parameters']
|
postData["grammar"] = grammar
|
||||||
function_call = body["function_call"]["name"]
|
# print("grammar")
|
||||||
postData["grammar"] = make_grammar(functions[function_call], function_call)
|
# print(grammar)
|
||||||
print("grammar")
|
|
||||||
print(postData["grammar"])
|
|
||||||
|
|
||||||
if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
|
if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
|
||||||
if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
|
if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
|
||||||
|
@ -194,7 +295,7 @@ def make_resData(data, chat=False, promptToken=[], function_call={}):
|
||||||
if (len(promptToken) != 0):
|
if (len(promptToken) != 0):
|
||||||
resData["promptToken"] = promptToken
|
resData["promptToken"] = promptToken
|
||||||
|
|
||||||
if chat and is_present(requestBody, "function_call") and is_present(function_call, "name"):
|
if chat and is_present(function_call, "name"):
|
||||||
resData["choices"][0]["delta"] = [{
|
resData["choices"][0]["delta"] = [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"function_call": {
|
"function_call": {
|
||||||
|
@ -280,9 +381,23 @@ def chat_completions():
|
||||||
body = request.get_json()
|
body = request.get_json()
|
||||||
stream = False
|
stream = False
|
||||||
tokenize = False
|
tokenize = False
|
||||||
|
function_call = None
|
||||||
if(is_present(body, "stream")): stream = body["stream"]
|
if(is_present(body, "stream")): stream = body["stream"]
|
||||||
if(is_present(body, "tokenize")): tokenize = body["tokenize"]
|
if(is_present(body, "tokenize")): tokenize = body["tokenize"]
|
||||||
postData = make_postData(body, chat=True, stream=stream)
|
if(is_present(body, "function_call")): function_call = body["function_call"]
|
||||||
|
if(is_present(body, "functions") and function_call is None): function_call = "auto"
|
||||||
|
|
||||||
|
if function_call == "auto":
|
||||||
|
postDataDecide = make_postData(body, chat=True, stream=False, function_call=function_call)
|
||||||
|
dataDecide = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postDataDecide))
|
||||||
|
decision = json.loads(dataDecide.content)
|
||||||
|
decision = json.loads(decision['content'])
|
||||||
|
if decision["decision"] == "message":
|
||||||
|
function_call = None
|
||||||
|
if decision["decision"] == "function":
|
||||||
|
function_call = {"name": decision["function_name"]}
|
||||||
|
|
||||||
|
postData = make_postData(body, chat=True, stream=stream, function_call=function_call)
|
||||||
|
|
||||||
promptToken = []
|
promptToken = []
|
||||||
if (tokenize):
|
if (tokenize):
|
||||||
|
@ -292,18 +407,18 @@ def chat_completions():
|
||||||
if (not stream):
|
if (not stream):
|
||||||
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
|
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
|
||||||
print(data.json())
|
print(data.json())
|
||||||
resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=body.get("function_call", {}))
|
resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=function_call)
|
||||||
return jsonify(resData)
|
return jsonify(resData)
|
||||||
else:
|
else:
|
||||||
def generate():
|
def generate():
|
||||||
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
||||||
time_now = int(time.time())
|
time_now = int(time.time())
|
||||||
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=body.get("function_call", {}))
|
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=function_call)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n'.format(json.dumps(resData))
|
||||||
for line in data.iter_lines():
|
for line in data.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=body.get("function_call", {}))
|
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=function_call)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n'.format(json.dumps(resData))
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream')
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue