Support the `tools` call field. Tested on both local and OpenAI.
This commit is contained in:
parent
155eeed9ae
commit
177414d522
2 changed files with 41 additions and 20 deletions
@@ -480,9 +480,9 @@ static json oaicompat_completion_params_parse(
     std::string function_str = "";

-    if (body.contains("tool") && !body["tool"].empty()) {
+    if (body.contains("tools") && !body["tools"].empty()) {
         // function_str = default_tool_formatter(body["tool"]);
-        function_str = rubra_format_function_call_str(body["tool"]);
+        function_str = rubra_format_function_call_str(body["tools"]);
     }
     // If 'tool' is not set or empty, check 'functions'
     else if (body.contains("functions") && !body["functions"].empty()) {
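For reference, a request body exercising the new branch might look like the sketch below: the parser now checks the OpenAI-style "tools" array before falling back to the legacy "functions" field. The endpoint, port, model name, and tool-schema shape are assumptions taken from the notebook changed in this same commit, not guarantees from the server code itself.

import requests

payload = {
    "model": "gorilla-openfunctions-v2",   # any model name the local server serves
    "messages": [{"role": "user", "content": "check the weather in boston"}],
    "tools": [{                            # OpenAI-style tool schema (assumed shape)
        "type": "function",
        "function": {
            "name": "getCurrentWeather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["c", "f"]},
                },
                "required": ["location"],
            },
        },
    }],
    "tool_choice": "auto",
}

resp = requests.post("http://localhost:8019/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"])

The second changed file, below, is the Jupyter notebook that exercises this endpoint.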
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 146,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -10,7 +10,7 @@
     "\n",
     "\n",
     "def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
-    "    openai.api_key = \"abc\"\n",
+    "    openai.api_key = \"sk-79\"\n",
     "    openai.base_url = \"http://localhost:8019/v1/\"\n",
     "    \n",
     "    try:\n",
@@ -18,8 +18,10 @@
     "            model=model,\n",
     "            temperature=0.1,\n",
     "            messages=msgs,\n",
-    "            functions=functions,\n",
-    "            function_call=\"auto\",\n",
+    "            tools=functions,\n",
+    "            tool_choice=\"auto\",\n",
+    "            # functions=functions,\n",
+    "            # function_call=\"auto\",\n",
     "            stream=False,\n",
     "        )\n",
     "        return completion.choices[0]\n",
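In plain Python, the updated notebook helper reads roughly as follows. This is a sketch, not the exact cell: it assumes the openai>=1.x client (which the tools/tool_choice parameters imply), the create call itself and the except clause are not visible in the hunk above, and the api_key/base_url values are simply the ones shown in the diff.

import openai

def get_mistral_rubra_response(prompt, model, functions, msgs):
    openai.api_key = "sk-79"                       # value used in the notebook
    openai.base_url = "http://localhost:8019/v1/"  # local OpenAI-compatible server

    try:
        completion = openai.chat.completions.create(
            model=model,
            temperature=0.1,
            messages=msgs,
            tools=functions,      # was: functions=functions
            tool_choice="auto",   # was: function_call="auto"
            stream=False,
        )
        return completion.choices[0]
    except Exception as e:        # the notebook's own except clause is not shown in this hunk
        print(e)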
@@ -29,14 +31,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 163,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      " <<functions>>[orderUmbrella(brand_name=\"Patagonia\")]\n"
+      "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n"
      ]
     }
    ],
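The Choice printed above carries the parsed call as a list of dicts in message.content (name / args / kwargs) rather than in tool_calls. A short sketch of reading it back out, based only on that printed output and assuming res is the value returned by get_mistral_rubra_response elsewhere in the notebook:

for call in res.message.content:   # e.g. [{'name': 'getCurrentWeather', 'args': [], 'kwargs': {...}}]
    pieces = [repr(a) for a in call["args"]]
    pieces += [f"{k}={v!r}" for k, v in call["kwargs"].items()]
    print(f'{call["name"]}({", ".join(pieces)})')
# -> getCurrentWeather(location='Boston', unit='c')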
@@ -80,20 +82,20 @@
     "\n",
     "\n",
     "user_query = \"check the weather in boston\"\n",
-    "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
-    "        ,\n",
-    "        {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
-    "        {\"role\": \"user\", \"content\": \"yes pls\"},\n",
-    "        ]\n",
-    "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
+    "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
+    "#         ,\n",
+    "#         {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
+    "#         {\"role\": \"user\", \"content\": \"yes pls\"},\n",
+    "#         ]\n",
+    "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
     "\n",
-    "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
-    "print(res.message.content)"
+    "res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
+    "print(res)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 177,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
     {
@@ -159,14 +161,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 183,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "<ast.List object at 0x105649390>\n",
       "[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n"
      ]
     }
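The cell output above shows function-call strings parsed into (name, positional_args, keyword_args) tuples. A minimal sketch of how that can be done with Python's ast module; this illustrates the idea and is not necessarily the parser used in the repository:

import ast

def parse_call_list(call_str):
    tree = ast.parse(call_str, mode="eval")   # e.g. '[f(1, x=2), g("a")]'
    calls = []
    for node in tree.body.elts:               # the string is a list of calls
        if not isinstance(node, ast.Call):
            continue
        name = node.func.id
        args = [ast.literal_eval(a) for a in node.args]
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        calls.append((name, args, kwargs))
    return calls

print(parse_call_list(
    '[get_current_weather(location="Boston, MA", api_key=123456789, unit="fahrenheit"), '
    'func("cde", x=1, b="2", c=[1, 2, {"a": 1, "b": 2}])]'
))
# [('get_current_weather', [], {'location': 'Boston, MA', ...}), ('func', ['cde'], {...})]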
@@ -216,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -282,6 +283,26 @@
     "\n",
     "print(functions)\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "46"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "\n"
+   ]
+  }
  ],
  "metadata": {