support tools call field. tested on both local and openai
This commit is contained in:
parent
155eeed9ae
commit
177414d522
2 changed files with 41 additions and 20 deletions
|
@ -480,9 +480,9 @@ static json oaicompat_completion_params_parse(
|
||||||
|
|
||||||
std::string function_str = "";
|
std::string function_str = "";
|
||||||
|
|
||||||
if (body.contains("tool") && !body["tool"].empty()) {
|
if (body.contains("tools") && !body["tools"].empty()) {
|
||||||
// function_str = default_tool_formatter(body["tool"]);
|
// function_str = default_tool_formatter(body["tool"]);
|
||||||
function_str = rubra_format_function_call_str(body["tool"]);
|
function_str = rubra_format_function_call_str(body["tools"]);
|
||||||
}
|
}
|
||||||
// If 'tool' is not set or empty, check 'functions'
|
// If 'tool' is not set or empty, check 'functions'
|
||||||
else if (body.contains("functions") && !body["functions"].empty()) {
|
else if (body.contains("functions") && !body["functions"].empty()) {
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 146,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -10,7 +10,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
|
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
|
||||||
" openai.api_key = \"abc\"\n",
|
" openai.api_key = \"sk-79\"\n",
|
||||||
" openai.base_url = \"http://localhost:8019/v1/\"\n",
|
" openai.base_url = \"http://localhost:8019/v1/\"\n",
|
||||||
" \n",
|
" \n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
|
@ -18,8 +18,10 @@
|
||||||
" model=model,\n",
|
" model=model,\n",
|
||||||
" temperature=0.1,\n",
|
" temperature=0.1,\n",
|
||||||
" messages=msgs,\n",
|
" messages=msgs,\n",
|
||||||
" functions=functions,\n",
|
" tools=functions,\n",
|
||||||
" function_call=\"auto\",\n",
|
" tool_choice=\"auto\",\n",
|
||||||
|
" # functions=functions,\n",
|
||||||
|
" # function_call=\"auto\",\n",
|
||||||
" stream=False,\n",
|
" stream=False,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" return completion.choices[0]\n",
|
" return completion.choices[0]\n",
|
||||||
|
@ -29,14 +31,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 163,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
" <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
|
"Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -80,20 +82,20 @@
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"user_query = \"check the weather in boston\"\n",
|
"user_query = \"check the weather in boston\"\n",
|
||||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
|
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
|
||||||
" ,\n",
|
"# ,\n",
|
||||||
" {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
|
"# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
|
||||||
" {\"role\": \"user\", \"content\": \"yes pls\"},\n",
|
"# {\"role\": \"user\", \"content\": \"yes pls\"},\n",
|
||||||
" ]\n",
|
"# ]\n",
|
||||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
|
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
"res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
|
||||||
"print(res.message.content)"
|
"print(res)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 177,
|
"execution_count": 28,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -159,14 +161,13 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 183,
|
"execution_count": 19,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"<ast.List object at 0x105649390>\n",
|
|
||||||
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
|
"[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -216,7 +217,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 20,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -282,6 +283,26 @@
|
||||||
"\n",
|
"\n",
|
||||||
"print(functions)\n"
|
"print(functions)\n"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"46"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue