From 6e81bc5f8b7caff4d91af821abc362ce3b99680c Mon Sep 17 00:00:00 2001 From: xaedes Date: Sun, 1 Oct 2023 22:41:54 +0200 Subject: [PATCH] add multiple functions, decision to send message or use function_call and include sent function results in chat prompt --- examples/server/api_like_OAI.py | 215 ++++++++++++++++++++++++-------- 1 file changed, 165 insertions(+), 50 deletions(-) diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index c35090ff7..b9ac88740 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -14,6 +14,7 @@ parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat comp parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") +parser.add_argument("--function-name", type=str, help="FUNCTION name in chat completions(default: '\\nFUNCTION: ')", default="\\nFUNCTION: ") parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '')", default="") parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") @@ -25,6 +26,8 @@ args = parser.parse_args() def is_present(json, key): try: buf = json[key] + except TypeError: + return False except KeyError: return False return True @@ -38,6 +41,7 @@ def convert_chat(messages): system_n = args.system_name.replace("\\n", "\n") user_n = args.user_name.replace("\\n", "\n") ai_n = args.ai_name.replace("\\n", "\n") + fn_n = args.function_name.replace("\\n", "\n") stop = args.stop.replace("\\n", "\n") @@ -46,34 +50,52 @@ def convert_chat(messages): prompt += f"{system_n}{line['content']}" if (line["role"] == "user"): prompt += f"{user_n}{line['content']}" + if (line["role"] == "function"): + prompt += f"{fn_n}{line['content']}" if (line["role"] == "assistant"): + # if (not is_present(line, 'content') or line['content'] is None): + # if is_present(line, 'function_call'): + # # line['content'] = json.dumps({'function_call': line['function_call']}) + # fname = line['function_call']['name'] + # line['content'] = f'call {fname}' + prompt += f"{ai_n}{line['content']}{stop}" prompt += ai_n.rstrip() return prompt def make_grammar(schema, root): + indent_inc = " " def format_rulename(name): return name.replace("_", "-") - def schema_typename(schema, defs, arrs): + def is_basetype(typename): + return typename in ['integer', 'number', 'bool', 'string'] + + def schema_typename(prefix, schema, defs, arrs): typename = '"null"' if 'type' in schema: typename = schema['type'] + if not is_basetype(typename) and typename != 'array': + typename = prefix + typename if '$ref' in schema and schema['$ref'] in defs: typename = defs[schema['$ref']] if typename == 'array': - elemtype = schema_typename(schema['items'], defs, arrs) - typename = format_rulename(f'array-{elemtype}') + elemtype = schema_typename(prefix, schema['items'], defs, arrs) + typename = f'array-{elemtype}' + if not is_basetype(elemtype): + typename = prefix + typename + typename = format_rulename(typename) if typename not in arrs: arrs[typename] = elemtype return typename - def arr_to_rule(name, elemtype): - rule = f'{name} ::= "[" ( {elemtype} ( "," {elemtype} )* )? "]"' - return rule + def arr_to_rules(rules, prefix, name, elemtype): + rulename = name + rulename = format_rulename(rulename) + rules[rulename] = f'"[" ( {elemtype} ( "," {elemtype} )* )? "]"' - def enum_to_rule(name, schema): + def enum_to_rules(rules, prefix, name, schema): enum_values = schema['enum'] etype = schema['type'] def value_pattern(value): @@ -85,44 +107,106 @@ def make_grammar(schema, root): value_pattern(value) for value in enum_values ]) - rule = f'{name} ::= ( {values} )' - return rule + rulename = format_rulename(f'{prefix}{name}') + rules[rulename] = f'( {values} )' - def obj_to_rule(name, schema, defs, arrs): - assert(schema['type'] == 'object') - def propery_to_grammar(name, typename): - return f'"\\"" "{name}" "\\"" ":" {typename}' - properties = '"," '.join([ - propery_to_grammar(name, schema_typename(property, defs, arrs)) - for name, property in schema['properties'].items() + def anyof_to_rules(rules, prefix, name, schema, defs, arrs): + values = schema['anyOf'] + values = ' | '.join([ + schema_typename(prefix, value, defs, arrs) + for value in values ]) - rule = f'{name} ::= "{{" {properties} "}}"' - return rule + rulename = format_rulename(f'{prefix}{name}') + rules[rulename] = f'( {values} )' + + def declare_rules(indent, rules, prefix, name, schema, defs, arrs): + # print(indent, "declare_rules() prefix", prefix) + # print(indent, "declare_rules() name", name) + # print(indent, "declare_rules() schema", schema) + if 'enum' in schema: + enum_to_rules(rules, prefix, format_rulename(name), schema) + elif 'anyOf' in schema: + anyof_to_rules(rules, prefix, format_rulename(name), schema, defs, arrs) + elif schema.get('type', None) == 'object': + obj_to_rules(indent + indent_inc, rules, prefix, format_rulename(name), schema, defs, arrs, is_toplevel=False) + # else: + # print(indent,"warning, did not declare any rules for (prefix,name,schema)", prefix, name, schema) + + def obj_to_rules(indent, rules, prefix, name, schema, defs, arrs, is_toplevel): + assert(schema['type'] == 'object') + if defs is None: + defs = {} + if arrs is None: + arrs = {} + + + rulename = name + if not is_toplevel: + rulename = prefix+rulename + rulename = format_rulename(rulename) + + # print(indent, "obj_to_rules() prefix", prefix) + # print(indent, "obj_to_rules() name", name) + # print(indent, "obj_to_rules() schema", schema) + # print(indent, "obj_to_rules() rulename", rulename) - def model_grammar(schema, root = None): - rules = [] - defs = {} - arrs = {} if '$defs' in schema: for name, _def in schema['$defs'].items(): - defs['#/$defs/' + name] = format_rulename(name) - + defs['#/$defs/' + name] = format_rulename(prefix + name) + + if '$defs' in schema: for name, _def in schema['$defs'].items(): - if 'enum' in _def: - rules.append(enum_to_rule(format_rulename(name), _def)) - elif _def['type'] == 'object': - rules.append(obj_to_rule(format_rulename(name), _def, defs, arrs)) + declare_rules(indent + indent_inc, rules, prefix, name, _def, defs, arrs) - if root is None: - root = schema["title"] - root = format_rulename(root) + for name, prop in schema['properties'].items(): + declare_rules(indent + indent_inc, rules, prefix, name, prop, defs, arrs) - if schema['type'] == 'object': - rules.append(obj_to_rule(root, schema, defs, arrs)) + def propery_to_grammar(name, typename): + return f'"\\"" "{name}" "\\"" ":" {typename}' + + properties = ' "," '.join([ + propery_to_grammar(name, schema_typename(prefix, prop, defs, arrs)) + for name, prop in schema['properties'].items() + ]) + # rule = f'{rulename} ::= "{{" {properties} "}}"' + rules[rulename] = f'"{{" {properties} "}}"' + # print("indent", "obj_to_rules() arrs", arrs) for arrtype, elemtype in arrs.items(): - rules.append(arr_to_rule(arrtype, elemtype)) - rules.append(f'root ::= {root}') + arr_to_rules(rules, prefix, arrtype, elemtype) + + return rulename + + def model_grammar(schema, root = None): + indent = "" + rules = {} + defs = {} + fns = {} + arrs = {} + + # print(indent, "model_grammar() schema") + # print(indent, schema) + # print(indent, "model_grammar() root", root) + + for fn in schema: + name = fn['name'] + params = fn['parameters'] + prefix = f"{name}-" + fns[name] = obj_to_rules(indent + indent_inc, rules, prefix, name, params, {}, {}, is_toplevel=True) + + # print(indent, "fns") + # print(indent, fns) + + # assert("name" in root) + # assert(root["name"] in fns) + root = format_rulename(fns[root["name"]]) + + rules['root'] = root + for k in rules: + if callable(rules[k]): + rules[k] = rules[k]() + + rules = [f'{k} ::= {v}' for k,v in rules.items()] grammar = "\n".join(rules) grammar += ( # json base types (r''' @@ -137,24 +221,41 @@ integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ''')) return grammar - return model_grammar(schema, root) -def make_postData(body, chat=False, stream=False): + def decision_grammar(schema, root): + fnames = [fn['name'] for fn in schema] + fnames = [f'"\\"" "{fn}" "\\""' for fn in fnames] + fnames = " | ".join(fnames) + rules = {} + rules["root"] = 'msg | call' + rules["msg"] = '"{" "\\"" "decision" "\\"" ":" "\\"" "message" "\\"" "}"' + rules["call"] = '"{" "\\"" "decision" "\\"" ":" "\\"" "function" "\\"" "," "\\"" "function_name" "\\"" ":" fname "}"' + rules["fname"] = fnames + rules = [f'{k} ::= {v}' for k,v in rules.items()] + grammar = "\n".join(rules) + return grammar + + if root == "auto": + return decision_grammar(schema, root) + if root is None: + return None + else: + return model_grammar(schema, root) + +def make_postData(body, chat=False, stream=False, decide_function=False, function_call=None): postData = {} if (chat): postData["prompt"] = convert_chat(body["messages"]) else: postData["prompt"] = body["prompt"] - if(is_present(body, "function_call") and is_present(body["function_call"], "name")): + if(is_present(body, "functions") and len(body["functions"])>0): assert(is_present(body, "functions")) - functions = {} - for function in body["functions"]: - functions[function['name']] = function['parameters'] - function_call = body["function_call"]["name"] - postData["grammar"] = make_grammar(functions[function_call], function_call) - print("grammar") - print(postData["grammar"]) + grammar = make_grammar(body["functions"], function_call) + if grammar is not None: + postData["grammar"] = grammar + # print("grammar") + # print(grammar) if(is_present(body, "temperature")): postData["temperature"] = body["temperature"] if(is_present(body, "top_k")): postData["top_k"] = body["top_k"] @@ -194,7 +295,7 @@ def make_resData(data, chat=False, promptToken=[], function_call={}): if (len(promptToken) != 0): resData["promptToken"] = promptToken - if chat and is_present(requestBody, "function_call") and is_present(function_call, "name"): + if chat and is_present(function_call, "name"): resData["choices"][0]["delta"] = [{ "index": 0, "function_call": { @@ -280,9 +381,23 @@ def chat_completions(): body = request.get_json() stream = False tokenize = False + function_call = None if(is_present(body, "stream")): stream = body["stream"] if(is_present(body, "tokenize")): tokenize = body["tokenize"] - postData = make_postData(body, chat=True, stream=stream) + if(is_present(body, "function_call")): function_call = body["function_call"] + if(is_present(body, "functions") and function_call is None): function_call = "auto" + + if function_call == "auto": + postDataDecide = make_postData(body, chat=True, stream=False, function_call=function_call) + dataDecide = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postDataDecide)) + decision = json.loads(dataDecide.content) + decision = json.loads(decision['content']) + if decision["decision"] == "message": + function_call = None + if decision["decision"] == "function": + function_call = {"name": decision["function_name"]} + + postData = make_postData(body, chat=True, stream=stream, function_call=function_call) promptToken = [] if (tokenize): @@ -292,18 +407,18 @@ def chat_completions(): if (not stream): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) print(data.json()) - resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=body.get("function_call", {})) + resData = make_resData(data.json(), chat=True, promptToken=promptToken, function_call=function_call) return jsonify(resData) else: def generate(): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) time_now = int(time.time()) - resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=body.get("function_call", {})) + resData = make_resData_stream({}, chat=True, time_now=time_now, start=True, function_call=function_call) yield 'data: {}\n'.format(json.dumps(resData)) for line in data.iter_lines(): if line: decoded_line = line.decode('utf-8') - resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=body.get("function_call", {})) + resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now, function_call=function_call) yield 'data: {}\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream')