Support the "tools" field in chat completion requests. Tested against both the local server and OpenAI.

Yingbei 2024-03-19 17:05:03 -07:00
parent 155eeed9ae
commit 177414d522
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 41 additions and 20 deletions

View file

@@ -480,9 +480,9 @@ static json oaicompat_completion_params_parse(
     std::string function_str = "";
-    if (body.contains("tool") && !body["tool"].empty()) {
+    if (body.contains("tools") && !body["tools"].empty()) {
         // function_str = default_tool_formatter(body["tool"]);
-        function_str = rubra_format_function_call_str(body["tool"]);
+        function_str = rubra_format_function_call_str(body["tools"]);
     }
     // If 'tool' is not set or empty, check 'functions'
     else if (body.contains("functions") && !body["functions"].empty()) {
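
The server-side hunk above switches the OpenAI-compatible endpoint from the old "tool" key to the standard "tools" array before handing it to rubra_format_function_call_str. A minimal sketch of a request that exercises the new branch, assuming the /v1/chat/completions route and the localhost:8019 port used in the notebook below; the model name and the OpenAI-style tool wrapping are illustrative (the notebook itself passes bare function schemas), and which shape the formatter ultimately expects is not visible in this hunk:

    import requests

    # Hypothetical request; the new branch only requires a non-empty "tools" key.
    body = {
        "model": "local-model",  # placeholder for whatever the server is hosting
        "messages": [{"role": "user", "content": "check the weather in boston"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "getCurrentWeather",
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            },
        }],
        "tool_choice": "auto",
    }

    resp = requests.post("http://localhost:8019/v1/chat/completions", json=body)
    print(resp.json()["choices"][0]["message"])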

View file

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 146,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -10,7 +10,7 @@
     "\n",
     "\n",
     "def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
-    "    openai.api_key = \"abc\"\n",
+    "    openai.api_key = \"sk-79\"\n",
     "    openai.base_url = \"http://localhost:8019/v1/\"\n",
     "    \n",
     "    try:\n",
@@ -18,8 +18,10 @@
     "        model=model,\n",
     "        temperature=0.1,\n",
     "        messages=msgs,\n",
-    "        functions=functions,\n",
-    "        function_call=\"auto\",\n",
+    "        tools=functions,\n",
+    "        tool_choice=\"auto\",\n",
+    "        # functions=functions,\n",
+    "        # function_call=\"auto\",\n",
     "        stream=False,\n",
     "    )\n",
     "    return completion.choices[0]\n",
@@ -29,14 +31,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 163,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      " <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
+      "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
      ]
     }
    ],
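
Reassembled from the two notebook hunks above: the helper now sends tools / tool_choice instead of the deprecated functions / function_call, and the returned Choice carries message.content as an already-parsed call list. A minimal sketch of the updated helper, assuming the openai>=1.x Python client; the except branch lies outside the diff, so the error handling here is a guess:

    import openai

    def get_mistral_rubra_response(prompt, model, functions, msgs):
        openai.api_key = "sk-79"                       # any placeholder key for the local server
        openai.base_url = "http://localhost:8019/v1/"
        try:
            completion = openai.chat.completions.create(
                model=model,
                temperature=0.1,
                messages=msgs,
                tools=functions,       # function schemas forwarded as-is, as in the notebook
                tool_choice="auto",
                stream=False,
            )
            return completion.choices[0]
        except Exception as e:         # assumed; the original except clause is not shown
            print(f"Request failed: {e}")
            return None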
@@ -80,20 +82,20 @@
     "\n",
     "\n",
     "user_query = \"check the weather in boston\"\n",
-    "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
-    "        ,\n",
-    "        {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
-    "        {\"role\": \"user\", \"content\": \"yes pls\"},\n",
-    "        ]\n",
-    "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
+    "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
+    "#        ,\n",
+    "#        {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
+    "#        {\"role\": \"user\", \"content\": \"yes pls\"},\n",
+    "#        ]\n",
+    "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
     "\n",
-    "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
-    "print(res.message.content)"
+    "res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
+    "print(res)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 177,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
     {
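
Since print(res) now shows message.content arriving as a list of {'name', 'args', 'kwargs'} dicts rather than a "<<functions>>[...]" string, the call can be dispatched directly. A small sketch of that consumption step; the weather stub and dispatch table are illustrative and not part of the notebook:

    # res is the Choice returned by get_mistral_rubra_response above.
    def getCurrentWeather(location, unit="f"):
        return f"72 {unit}, rainy in {location}"   # stub standing in for a real lookup

    dispatch = {"getCurrentWeather": getCurrentWeather}

    for call in res.message.content:
        handler = dispatch[call["name"]]
        print(handler(*call["args"], **call["kwargs"]))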
@@ -159,14 +161,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 183,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "<ast.List object at 0x105649390>\n",
       "[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
      ]
     }
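
The tuple output above, together with the dropped "<ast.List object at ...>" debug print, suggests the notebook parses the model's call string with Python's ast module. A minimal sketch of such a parser, assuming the input is a bracketed list of calls like the one in the test; the notebook's actual helper may differ:

    import ast

    def parse_function_call_str(call_str):
        # Turn '[f("a", x=1), g(...)]' into a list of (name, args, kwargs) tuples.
        tree = ast.parse(call_str, mode="eval")
        nodes = tree.body.elts if isinstance(tree.body, ast.List) else [tree.body]
        parsed = []
        for node in nodes:
            if not isinstance(node, ast.Call):
                continue
            name = node.func.id
            args = [ast.literal_eval(a) for a in node.args]
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
            parsed.append((name, args, kwargs))
        return parsed

    print(parse_function_call_str(
        '[get_current_weather(location="Boston, MA", api_key=123456789, unit="fahrenheit"), '
        'func("cde", x=1, b="2", c=[1, 2, {"a": 1, "b": 2}])]'
    ))
    # -> [('get_current_weather', [], {...}), ('func', ['cde'], {...})]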
@@ -216,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -282,6 +283,26 @@
     "\n",
     "print(functions)\n"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "46"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "\n"
+   ]
   }
  ],
  "metadata": {