diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 62d04e688..4320fda2a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -480,9 +480,9 @@ static json oaicompat_completion_params_parse( std::string function_str = ""; - if (body.contains("tool") && !body["tool"].empty()) { + if (body.contains("tools") && !body["tools"].empty()) { // function_str = default_tool_formatter(body["tool"]); - function_str = rubra_format_function_call_str(body["tool"]); + function_str = rubra_format_function_call_str(body["tools"]); } // If 'tool' is not set or empty, check 'functions' else if (body.contains("functions") && !body["functions"].empty()) { diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb index 5302ac8b8..4377a68c8 100644 --- a/test_llamacpp.ipynb +++ b/test_llamacpp.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 146, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -10,7 +10,7 @@ "\n", "\n", "def get_mistral_rubra_response(prompt, model, functions, msgs):\n", - " openai.api_key = \"abc\"\n", + " openai.api_key = \"sk-79\"\n", " openai.base_url = \"http://localhost:8019/v1/\"\n", " \n", " try:\n", @@ -18,8 +18,10 @@ " model=model,\n", " temperature=0.1,\n", " messages=msgs,\n", - " functions=functions,\n", - " function_call=\"auto\",\n", + " tools=functions,\n", + " tool_choice=\"auto\",\n", + " # functions=functions,\n", + " # function_call=\"auto\",\n", " stream=False,\n", " )\n", " return completion.choices[0]\n", @@ -29,14 +31,14 @@ }, { "cell_type": "code", - "execution_count": 163, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " <>[orderUmbrella(brand_name=\"Patagonia\")]\n" + "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n" ] } ], @@ -80,20 +82,20 @@ "\n", "\n", "user_query = \"check the weather in boston\"\n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", - " ,\n", - " {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n", - " {\"role\": \"user\", \"content\": \"yes pls\"},\n", - " ]\n", - "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", + "# ,\n", + "# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n", + "# {\"role\": \"user\", \"content\": \"yes pls\"},\n", + "# ]\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", "\n", - "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", - "print(res.message.content)" + "res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n", + "print(res)" ] }, { "cell_type": "code", - "execution_count": 177, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -159,14 +161,13 @@ }, { "cell_type": "code", - "execution_count": 183, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n", "[('get_current_weather', [], {'location': 'Boston, MA', 'api_key': 123456789, 'unit': 'fahrenheit'}), ('func', ['cde'], {'x': 1, 'b': '2', 'c': [1, 2, {'a': 1, 'b': 2}]})]\n" ] } @@ -216,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -282,6 +283,26 @@ "\n", "print(functions)\n" ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "46" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n" + ] } ], "metadata": {