return function call in OAI format -- tool_calls field

Yingbei 2024-03-20 15:52:00 -07:00
parent 513fbf094e
commit a5b2aa58cf
3 changed files with 123 additions and 12 deletions


@@ -39,6 +39,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
if (strcmp(type, "call") == 0) {
json call = {
{"id", calls.size()},
{"name", ""},
{"args", json::array()},
{"kwargs", json::object()}
@@ -50,7 +51,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
// Extract the function name
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
// Loop through the arguments
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
for (unsigned int i = 0; i < numArgs; ++i) {
TSNode argNode = ts_node_named_child(argumentsNode, i);
@@ -58,7 +58,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
// Check if the argument is a positional argument or a keyword argument
if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) {
// For simplification, we treat the entire content as the argument
std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode));
call["args"].push_back(parseValue(value));
} else if (strcmp(argType, "keyword_argument") == 0) {
@@ -89,6 +88,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
}
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
// Parse Python function calls from the source code and return a JSON array
std::vector<json> calls;
std::string delimiter = "<<functions>>";
std::string source_code;
@@ -116,8 +116,6 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
// Output the parsed calls
ts_tree_delete(tree);
ts_parser_delete(parser);
printf("calls: %s\n", json(calls).dump().c_str());


@@ -585,9 +585,22 @@ static json format_final_response_oaicompat(const json & request, json result, c
{"message", json{{"content", content},
{"role", "assistant"}}}}});
} else {
+std::vector<json> oai_format_tool_calls;
+for (size_t i = 0; i < parsed_content.size(); ++i) {
+    const auto &pc = parsed_content[i];
+    // Convert each parsed call into an OAI-format tool_call entry
+    json tool_call;
+    tool_call["id"] = pc["id"];
+    tool_call["type"] = "function";
+    tool_call["function"] = json{
+        {"name", pc["name"]},
+        {"arguments", pc["kwargs"].dump()},
+    };
+    oai_format_tool_calls.push_back(tool_call);
+}
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", parsed_content},
{"message", json{{"tool_calls", oai_format_tool_calls},
{"role", "assistant"}}}}});
}
}
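With this branch in place, a tool-call response from the server carries the OpenAI-style tool_calls field instead of the raw parsed content. A sketch of the resulting choices payload as Python literals, inferred from the loop above and the notebook output below; note that the hosted OpenAI API returns string ids such as "call_...", while this code forwards the parser's integer index:

# One choice whose message carries a tool call rather than text content.
choices = [
    {
        "finish_reason": "stop",  # forwarded as-is by the server
        "index": 0,
        "message": {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": 0,  # integer index from the parser, not an OpenAI-style string id
                    "type": "function",
                    "function": {
                        "name": "getCurrentWeather",
                        "arguments": '{"location": "Boston", "unit": "c"}',
                    },
                }
            ],
        },
    }
]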


@@ -1,17 +1,24 @@
{
"cells": [
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## OpenAI"
+]
+},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"\n",
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"sk-79\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
"def get_oai_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"sk-\"\n",
" # openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
@@ -31,18 +38,111 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_NVPCmdRCtY9lO7sVhOFRPS6R', function=Function(arguments='{\"location\":\"Boston, MA\",\"unit\":\"f\"}', name='getCurrentWeather'), type='function')]))\n"
]
}
],
"source": [
"system_prompt = \"You are a helpful assistant.\"\n",
"functions = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"getCurrentWeather\",\n",
" \"description\": \"Get the weather in location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" }\n",
" },\n",
" { \"type\": \"function\",\n",
" \"function\":\n",
" {\n",
" \"name\": \"orderUmbrella\",\n",
" \"description\": \"Do this to help user to order an umbrella online\", \n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"brand_name\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The name of the umbrella brand\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"brand_name\"\n",
" ]\n",
" }\n",
" }},\n",
"]\n",
"\n",
"\n",
"user_query = \"check the weather in boston\"\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
"# ,\n",
"# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
"# {\"role\": \"user\", \"content\": \"yes pls\"},\n",
"# ]\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
"\n",
"res = get_oai_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function.cpp"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id=0, function=[Function(arguments='{\"location\":\"Boston\",\"unit\":\"c\"}', name='getCurrentWeather')], type='function')]))\n"
]
}
],
"source": [
"import openai\n",
"\n",
"\n",
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"sk-\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
" try:\n",
" completion = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" tools=functions,\n",
" tool_choice=\"auto\",\n",
" # functions=functions,\n",
" # function_call=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
" except Exception as e:\n",
" print(e, model, prompt)\n",
"\n",
"system_prompt = \"You are a helpful assistant.\"\n",
"functions = [\n",
" {\n",
@@ -95,7 +195,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 7,
"metadata": {},
"outputs": [
{