support tools call field. tested on both local and openai

This commit is contained in:
Yingbei 2024-03-19 17:05:03 -07:00
parent 155eeed9ae
commit 177414d522
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 41 additions and 20 deletions

View file

@ -480,9 +480,9 @@ static json oaicompat_completion_params_parse(
std::string function_str = "";
if (body.contains("tool") && !body["tool"].empty()) {
if (body.contains("tools") && !body["tools"].empty()) {
// function_str = default_tool_formatter(body["tool"]);
function_str = rubra_format_function_call_str(body["tool"]);
function_str = rubra_format_function_call_str(body["tools"]);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {

View file

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 146,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -10,7 +10,7 @@
"\n",
"\n",
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"abc\"\n",
" openai.api_key = \"sk-79\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
" try:\n",
@ -18,8 +18,10 @@
" model=model,\n",
" temperature=0.1,\n",
" messages=msgs,\n",
" functions=functions,\n",
" function_call=\"auto\",\n",
" tools=functions,\n",
" tool_choice=\"auto\",\n",
" # functions=functions,\n",
" # function_call=\"auto\",\n",
" stream=False,\n",
" )\n",
" return completion.choices[0]\n",
@ -29,14 +31,14 @@
},
{
"cell_type": "code",
"execution_count": 163,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
]
}
],
@ -80,20 +82,20 @@
"\n",
"\n",
"user_query = \"check the weather in boston\"\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
" ,\n",
" {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
" {\"role\": \"user\", \"content\": \"yes pls\"},\n",
" ]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
"# ,\n",
"# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
"# {\"role\": \"user\", \"content\": \"yes pls\"},\n",
"# ]\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
"\n",
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"print(res.message.content)"
"res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 177,
"execution_count": 28,
"metadata": {},
"outputs": [
{
@ -159,14 +161,13 @@
},
{
"cell_type": "code",
"execution_count": 183,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<ast.List object at 0x105649390>\n",
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
]
}
@ -216,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 20,
"metadata": {},
"outputs": [
{
@ -282,6 +283,26 @@
"\n",
"print(functions)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"46"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n"
]
}
],
"metadata": {