return function call in OAI format -- tools_call field
This commit is contained in:
parent
513fbf094e
commit
a5b2aa58cf
3 changed files with 123 additions and 12 deletions
|
@ -39,6 +39,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
|||
if (strcmp(type, "call") == 0) {
|
||||
|
||||
json call = {
|
||||
{"id", calls.size()},
|
||||
{"name", ""},
|
||||
{"args", json::array()},
|
||||
{"kwargs", json::object()}
|
||||
|
@ -50,7 +51,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
|||
// Extract the function name
|
||||
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
|
||||
|
||||
// Loop through the arguments
|
||||
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
|
||||
for (unsigned int i = 0; i < numArgs; ++i) {
|
||||
TSNode argNode = ts_node_named_child(argumentsNode, i);
|
||||
|
@ -58,7 +58,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
|||
|
||||
// Check if the argument is a positional argument or a keyword argument
|
||||
if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) {
|
||||
// For simplification, we treat the entire content as the argument
|
||||
std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode));
|
||||
call["args"].push_back(parseValue(value));
|
||||
} else if (strcmp(argType, "keyword_argument") == 0) {
|
||||
|
@ -89,6 +88,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
|
|||
}
|
||||
|
||||
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
||||
// Parse Python function calls from the source code and return a JSON array
|
||||
std::vector<json> calls;
|
||||
std::string delimiter = "<<functions>>";
|
||||
std::string source_code;
|
||||
|
@ -116,8 +116,6 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
|
|||
|
||||
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
|
||||
|
||||
// Output the parsed calls
|
||||
|
||||
ts_tree_delete(tree);
|
||||
ts_parser_delete(parser);
|
||||
printf("calls: %s\n", json(calls).dump().c_str());
|
||||
|
|
|
@ -585,9 +585,22 @@ static json format_final_response_oaicompat(const json & request, json result, c
|
|||
{"message", json{{"content", content},
|
||||
{"role", "assistant"}}}}});
|
||||
} else {
|
||||
std::vector<json> oai_format_tool_calls;
|
||||
for (size_t i = 0; i < parsed_content.size(); ++i) {
|
||||
const auto &pc = parsed_content[i];
|
||||
// Use 'pc' and 'i' as needed
|
||||
json tool_call;
|
||||
tool_call["id"] = pc["id"];
|
||||
tool_call["type"] = "function";
|
||||
tool_call["function"] = json{{
|
||||
{"name" , pc["name"]},
|
||||
{"arguments" , pc["kwargs"].dump()},
|
||||
}};
|
||||
oai_format_tool_calls.push_back(tool_call);
|
||||
}
|
||||
choices = json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{{"content", parsed_content},
|
||||
{"message", json{{"tool_calls", oai_format_tool_calls},
|
||||
{"role", "assistant"}}}}});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## OpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
|
||||
" openai.api_key = \"sk-79\"\n",
|
||||
" openai.base_url = \"http://localhost:8019/v1/\"\n",
|
||||
"def get_oai_response(prompt, model, functions, msgs):\n",
|
||||
" openai.api_key = \"sk-\"\n",
|
||||
" # openai.base_url = \"http://localhost:8019/v1/\"\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" completion = openai.chat.completions.create(\n",
|
||||
|
@ -31,18 +38,111 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
|
||||
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_NVPCmdRCtY9lO7sVhOFRPS6R', function=Function(arguments='{\"location\":\"Boston, MA\",\"unit\":\"f\"}', name='getCurrentWeather'), type='function')]))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"system_prompt = \"You are a helpful assistant.\"\n",
|
||||
"functions = [\n",
|
||||
" {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\n",
|
||||
" \"name\": \"getCurrentWeather\",\n",
|
||||
" \"description\": \"Get the weather in location\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n",
|
||||
" \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" { \"type\": \"function\",\n",
|
||||
" \"function\":\n",
|
||||
" {\n",
|
||||
" \"name\": \"orderUmbrella\",\n",
|
||||
" \"description\": \"Do this to help user to order an umbrella online\", \n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"brand_name\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The name of the umbrella brand\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\n",
|
||||
" \"brand_name\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_query = \"check the weather in boston\"\n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
|
||||
"# ,\n",
|
||||
"# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
|
||||
"# {\"role\": \"user\", \"content\": \"yes pls\"},\n",
|
||||
"# ]\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
|
||||
"\n",
|
||||
"res = get_oai_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
|
||||
"print(res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Function.cpp"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id=0, function=[Function(arguments='{\"location\":\"Boston\",\"unit\":\"c\"}', name='getCurrentWeather')], type='function')]))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
|
||||
" openai.api_key = \"sk-\"\n",
|
||||
" openai.base_url = \"http://localhost:8019/v1/\"\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" completion = openai.chat.completions.create(\n",
|
||||
" model=model,\n",
|
||||
" temperature=0.1,\n",
|
||||
" messages=msgs,\n",
|
||||
" tools=functions,\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" # functions=functions,\n",
|
||||
" # function_call=\"auto\",\n",
|
||||
" stream=False,\n",
|
||||
" )\n",
|
||||
" return completion.choices[0]\n",
|
||||
" except Exception as e:\n",
|
||||
" print(e, model, prompt)\n",
|
||||
"\n",
|
||||
"system_prompt = \"You are a helpful assistant.\"\n",
|
||||
"functions = [\n",
|
||||
" {\n",
|
||||
|
@ -95,7 +195,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue