From 784fa90cbe7c86c1faeb78105e1bcd0ae8938bf7 Mon Sep 17 00:00:00 2001
From: Yingbei
Date: Mon, 25 Mar 2024 17:00:26 -0700
Subject: [PATCH] Add support for parsing OpenAI function-call inputs and tool
 results into the mistral_rubra format. TODO: clean up the debug prints after
 testing.

---
 examples/server/python-parser.hpp |   2 +-
 examples/server/utils.hpp         | 125 ++++++++++++--
 llama.cpp                         |   3 +-
 test_llamacpp.ipynb               | 274 ++++++++++++++----------------
 4 files changed, 249 insertions(+), 155 deletions(-)

diff --git a/examples/server/python-parser.hpp b/examples/server/python-parser.hpp
index 8efff4955..5037e381d 100644
--- a/examples/server/python-parser.hpp
+++ b/examples/server/python-parser.hpp
@@ -39,7 +39,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
     if (strcmp(type, "call") == 0) {
         json call = {
-            {"id", calls.size()},
+            {"id", std::to_string(calls.size())},
             {"name", ""},
             {"args", json::array()},
             {"kwargs", json::object()}
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index f12b5307c..d94ffe32d 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -10,6 +10,7 @@
 #include <string>
 #include <vector>
 #include <sstream>
+#include <unordered_map>
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -491,19 +492,123 @@ static json oaicompat_completion_params_parse(
         // function_str = default_tool_formatter(body["functions"]);
         function_str = rubra_format_function_call_str(body["functions"]);
     }
-
+    printf("\n=============Formatting Input from OPENAI format...============\n");
     if (function_str != "") {
         const std::vector<json> expand_messages = [&]() {
-            std::vector<json> temp_vec = body["messages"];
-            if (body["messages"][0]["role"] == "system") {
-                std::string old_content = temp_vec[0]["content"];
-                temp_vec[0]["content"] = old_content + "\n" + function_str;
+            // std::vector<json> temp_vec = body["messages"];
+            // if (body["messages"][0]["role"] == "system") {
+            //     std::string old_content = temp_vec[0]["content"];
+            //     temp_vec[0]["content"] = old_content + "\n" + function_str;
+            // }
+            // else {
+            //     json function_call;
+            //     function_call["role"] = "system";
+            //     function_call["content"] = "You are a helpful assistant.\n" + function_str;
+            //     temp_vec.push_back(function_call);
+            // }
+            std::vector<json> temp_vec;
+            std::unordered_map<std::string, std::string> func_observation_map;
+            for (size_t i = 0; i < body["messages"].size(); ++i) {
+                printf("body[\"messages\"][%zu][\"role\"] = %s\n", i, body["messages"][i]["role"].get<std::string>().c_str());
+                printf("Message: %s\n", body["messages"][i].dump().c_str());
+                printf("%d\n", body["messages"][i].contains("tool_calls"));
+
+                if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
+                    // insert the observation from the tool call before the next message
+                    std::string observation_str = "";
+                    for (const auto& [key, value] : func_observation_map) {
+                        if (observation_str != "") {
+                            observation_str += ", ";
+                        }
+                        observation_str += value;
+                    }
+                    observation_str = std::string("<>") + "[" + observation_str + "]";
+                    json observation_call;
+                    observation_call["role"] = "observation";
+                    observation_call["content"] = observation_str;
+                    temp_vec.push_back(observation_call);
+                    func_observation_map.clear();
+                }
+
+                if (i == 0) {
+                    if (body["messages"][0]["role"] == "system") {
+                        std::string old_content = body["messages"][0]["content"];
+                        json function_call;
+                        function_call["role"] = "system";
+                        function_call["content"] = old_content + "\n" + function_str;
+                        temp_vec.push_back(function_call);
+                    }
+                    else { // insert a system message with the tool definitions before the first message
+                        json function_call;
+                        function_call["role"] = "system";
+                        function_call["content"] = "You are a helpful assistant.\n" + function_str;
+                        temp_vec.push_back(function_call);
+                        temp_vec.push_back(body["messages"][0]);
+                    }
+                }
+                // else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"] == "") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()) {
+                else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")) {
+                    printf("Tool call detected\n");
+                    // convert the OpenAI function-call format to the Rubra format
+                    std::string tool_call_str = "";
+                    printf("Tool calls: %s\n", body["messages"][i]["tool_calls"].dump().c_str());
+                    for (const auto & tool_call : body["messages"][i]["tool_calls"]) {
+                        printf("Tool call id: %s\n", tool_call["id"].get<std::string>().c_str());
+                        std::string func_str = "";
+                        func_observation_map[tool_call["id"].get<std::string>()] = ""; // initialize with an empty value; updated later with the content of the matching "tool" role message
+                        json args = json::parse(tool_call["function"]["arguments"].get<std::string>()); // TODO: catch parse exceptions
+                        for (auto& arg : args.items()) {
+                            if (func_str != "") {
+                                func_str += ", ";
+                            }
+                            func_str += arg.key() + "=" + arg.value().dump();
+                        }
+                        func_str = tool_call["function"]["name"].get<std::string>() + "(" + func_str + ")";
+                        if (tool_call_str != "") {
+                            tool_call_str += ", ";
+                        }
+                        tool_call_str += func_str;
+                    }
+                    tool_call_str = std::string("<>") + "[" + tool_call_str + "]";
+                    printf("Tool call string: %s\n", tool_call_str.c_str());
+
+                    json function_call;
+                    function_call["role"] = "function";
+                    function_call["content"] = tool_call_str;
+                    temp_vec.push_back(function_call);
+                }
+                else if (body["messages"][i]["role"] == "tool") {
+                    printf("Observation detected\n");
+                    printf("%s\n", body["messages"][i].dump().c_str());
+                    std::string tool_call_id = body["messages"][i]["tool_call_id"].get<std::string>();
+                    if (func_observation_map.find(tool_call_id) != func_observation_map.end()) {
+                        func_observation_map[tool_call_id] = body["messages"][i]["content"].get<std::string>();
+                    } else {
+                        LOG_ERROR("Tool call id not found in the map", {{"tool_call_id", tool_call_id}});
+                        // TODO: the input is invalid in this case; an error should be returned
+                    }
+                }
+                else {
+                    temp_vec.push_back(body["messages"][i]);
+                }
             }
-            else {
-                json function_call;
-                function_call["role"] = "system";
-                function_call["content"] = "You are a helpful assistant.\n" + function_str;
-                temp_vec.push_back(function_call);
+            if (func_observation_map.size() > 0) {
+                // insert any remaining observations after the last message
+                std::string observation_str = "";
+                for (const auto& [key, value] : func_observation_map) {
+                    if (observation_str != "") {
+                        observation_str += ", ";
+                    }
+                    observation_str += value;
+                }
+                observation_str = std::string("<>") + "[" + observation_str + "]";
+                json observation_call;
+                observation_call["role"] = "observation";
+                observation_call["content"] = observation_str;
+                temp_vec.push_back(observation_call);
+                func_observation_map.clear();
             }
             return temp_vec;
         }();
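The transformation this hunk implements is easier to review outside of C++. Below is an illustrative Python sketch of the same loop — not part of the patch. The bare `<>` delimiter copies the string literal used above (the fork defines the actual Rubra marker tokens elsewhere), and `function_str` stands for the serialized tool definitions returned by `rubra_format_function_call_str`:

```python
import json

def expand_messages(messages, function_str):
    """Fold OpenAI-style tool_calls/tool messages into Rubra-style
    "function" and "observation" messages, mirroring the C++ loop above."""
    out, pending = [], {}  # pending maps tool_call_id -> tool result content

    def flush_observations():
        if pending:
            out.append({"role": "observation",
                        "content": "<>[" + ", ".join(pending.values()) + "]"})
            pending.clear()

    for i, msg in enumerate(messages):
        if msg["role"] != "tool":
            flush_observations()  # emit buffered tool results before the next turn
        if i == 0:
            if msg["role"] == "system":
                out.append({"role": "system",
                            "content": msg["content"] + "\n" + function_str})
            else:
                out.append({"role": "system",
                            "content": "You are a helpful assistant.\n" + function_str})
                out.append(msg)
        elif msg["role"] == "assistant" and "tool_calls" in msg:
            calls = []
            for tc in msg["tool_calls"]:
                pending[tc["id"]] = ""  # placeholder until the matching "tool" message arrives
                args = json.loads(tc["function"]["arguments"])
                rendered = ", ".join(f"{k}={json.dumps(v)}" for k, v in args.items())
                calls.append(f'{tc["function"]["name"]}({rendered})')
            out.append({"role": "function",
                        "content": "<>[" + ", ".join(calls) + "]"})
        elif msg["role"] == "tool":
            pending[msg["tool_call_id"]] = msg["content"]
        else:
            out.append(msg)
    flush_observations()  # trailing tool results at the end of the list
    return out
```

For the umbrella exchange exercised in the notebook below, this yields a `function` turn of `<>[orderUmbrella(number_to_buy=2)]` followed by an `observation` turn of `<>[Order placed.]` — the single-string shape the mistral_rubra chat template consumes.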
diff --git a/llama.cpp b/llama.cpp
index ed46216b5..2b04968dd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14518,7 +14518,6 @@ static int32_t llama_chat_apply_template_internal(
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         // ss << "[INST] ";
-
        for (auto message : chat) {
            std::string content = strip_message ? trim(message->content) : message->content;
            std::string role(message->role);
@@ -14537,7 +14536,7 @@ static int32_t llama_chat_apply_template_internal(
                ss << "[INST]" << content << " [/INST]";
            } else {
                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
-               is_inside_turn = false;
+               // is_inside_turn = false;
            }
        }
        // llama2 templates seem to not care about "add_generation_prompt"
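The effect of keeping `is_inside_turn` set is that an assistant reply no longer closes the instruction turn, so the `function` and `observation` messages injected by utils.hpp render inline after it rather than each reopening an `[INST]` block. A toy rendering of that behavior — a sketch only, eliding the system-prompt, BOS, and spacing details the real template code handles:

```python
def render_llama2_style(messages, space_around_response=True):
    # Toy model of the patched template loop: a user turn opens an
    # [INST] ... [/INST] block; every other role streams into the
    # current turn. With `is_inside_turn = false` commented out,
    # consecutive assistant/function/observation turns concatenate
    # without re-opening [INST].
    pad = " " if space_around_response else ""
    ss = ""
    for msg in messages:
        if msg["role"] == "user":
            ss += "[INST]" + msg["content"] + " [/INST]"
        else:
            ss += pad + msg["content"] + pad + "</s>"
    return ss
```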
diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb
index 2ef7032fd..df992226f 100644
--- a/test_llamacpp.ipynb
+++ b/test_llamacpp.ipynb
@@ -1,5 +1,82 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def run_completion(chat_method, user_query):\n",
+    "    print(chat_method)\n",
+    "    system_prompt = \"You are a helpful assistant.\"\n",
+    "    functions = [\n",
+    "        {\n",
+    "            \"type\": \"function\",\n",
+    "            \"function\": {\n",
+    "                \"name\": \"getCurrentWeather\",\n",
+    "                \"description\": \"Get the weather in location\",\n",
+    "                \"parameters\": {\n",
+    "                    \"type\": \"object\",\n",
+    "                    \"properties\": {\n",
+    "                        \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. San Francisco, CA\"},\n",
+    "                        \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n",
+    "                    },\n",
+    "                    \"required\": [\"location\"]\n",
+    "                }\n",
+    "            }\n",
+    "        },\n",
+    "        { \"type\": \"function\",\n",
+    "          \"function\":\n",
+    "            {\n",
+    "                \"name\": \"orderUmbrella\",\n",
+    "                \"description\": \"Use this to help the user order an umbrella online\",\n",
+    "                \"parameters\": {\n",
+    "                    \"type\": \"object\",\n",
+    "                    \"properties\": {\n",
+    "                        \"number_to_buy\": {\n",
+    "                            \"type\": \"integer\",\n",
+    "                            \"description\": \"the number of umbrellas to buy\"\n",
+    "                        }\n",
+    "                    },\n",
+    "                    \"required\": [\n",
+    "                        \"number_to_buy\"\n",
+    "                    ]\n",
+    "                }\n",
+    "            }},\n",
+    "    ]\n",
+    "    # functions = [{\"type\": \"function\",\"function\":{\"name\":\"calculate_distance\",\"description\":\"Calculate the distance between two locations\",\"parameters\":{\"type\":\"object\",\"properties\":{\"origin\":{\"type\":\"string\",\"description\":\"The starting location\"},\"destination\":{\"type\":\"string\",\"description\":\"The destination location\"},\"mode\":{\"type\":\"string\",\"description\":\"The mode of transportation\"}},\"required\":[\"origin\",\"destination\",\"mode\"]}}},{\"type\": \"function\",\"function\":{\"name\":\"generate_password\",\"description\":\"Generate a random password\",\"parameters\":{\"type\":\"object\",\"properties\":{\"length\":{\"type\":\"integer\",\"description\":\"The length of the password\"}},\"required\":[\"length\"]}}}]\n",
+    "\n",
+    "    msgs = [{\"role\": \"system\", \"content\": system_prompt}, {\"role\": \"user\", \"content\": user_query}]\n",
+    "\n",
+    "    res = chat_method(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
+    "    print(\"First Response:\")\n",
+    "    for tool_call in res.message.tool_calls:\n",
+    "        print(f\"Tool Call: {tool_call.id}, {tool_call.function}\")\n",
+    "    assistant_message = res.message\n",
+    "    tool_calls = []\n",
+    "    for tool_call in assistant_message.tool_calls:\n",
+    "        tool_calls.append({\n",
+    "            \"id\": tool_call.id,\n",
+    "            \"function\": {\"name\": tool_call.function.name,\n",
+    "                         \"arguments\": tool_call.function.arguments},\n",
+    "            \"type\": \"function\",\n",
+    "        })\n",
+    "    msgs.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n",
+    "\n",
+    "    for i, tool_call in enumerate(assistant_message.tool_calls):\n",
+    "        if tool_call.function.name == \"getCurrentWeather\":\n",
+    "            msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temperature is {i * 50} degrees\"})\n",
+    "        else:\n",
+    "            msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": \"Order placed.\"})\n",
+    "\n",
+    "    print(\"Print before second response...\")\n",
+    "    res_next = chat_method(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n",
+    "    for m in msgs:\n",
+    "        print(m)\n",
+    "    print(f\"Second Response: {res_next.message}\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
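The message shapes `run_completion` builds are the contract that the server-side conversion in utils.hpp depends on: an assistant message carrying `tool_calls`, then one `tool` message per call whose `tool_call_id` echoes the call's id. Spelled out with the values the local run further below actually produces:

```python
msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "order 2 umbrellas"},
    # converted by the server into a "function" role message
    {"role": "assistant", "tool_calls": [
        {"id": "0", "type": "function",
         "function": {"name": "orderUmbrella", "arguments": '{"number_to_buy":"2"}'}},
    ]},
    # converted by the server into an "observation" role message
    {"role": "tool", "tool_call_id": "0", "name": "orderUmbrella",
     "content": "Order placed."},
]
```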
\"getCurrentWeather\":\n", + " msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"temprature is {i * 50} degree\"})\n", + " else:\n", + " msgs.append({\"role\": \"tool\", \"tool_call_id\": str(assistant_message.tool_calls[i].id), \"name\": assistant_message.tool_calls[i].function.name, \"content\": f\"Order placed.\"})\n", + " \n", + "\n", + " print(\"Print before second response...\")\n", + " res_next = chat_method(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n", + " for m in msgs:\n", + " print(m)\n", + " print(f\"Second Response: {res_next.message}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -9,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -17,8 +94,8 @@ "\n", "\n", "def get_oai_response(prompt, model, functions, msgs):\n", - " openai.api_key = \"sk-\"\n", - " # openai.base_url = \"http://localhost:8019/v1/\"\n", + " openai.api_key = \"sk-\" ## Add your API key here\n", + " openai.base_url = \"https://api.openai.com/v1/\"\n", " \n", " try:\n", " completion = openai.chat.completions.create(\n", @@ -38,66 +115,30 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 71, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_NVPCmdRCtY9lO7sVhOFRPS6R', function=Function(arguments='{\"location\":\"Boston, MA\",\"unit\":\"f\"}', name='getCurrentWeather'), type='function')]))\n" + "\n", + "First Response:\n", + "Tool Call: call_FYLEpX5CVo2dqSyNupcgtFak, Function(arguments='{\"number_to_buy\":2}', name='orderUmbrella')\n", + "Print before second response...\n", + "{'role': 'system', 'content': 'You are a helpful assistant.'}\n", + "{'role': 'user', 'content': 'order 2 umbrellas'}\n", + "{'role': 'assistant', 'tool_calls': [{'id': 'call_FYLEpX5CVo2dqSyNupcgtFak', 'function': {'name': 'orderUmbrella', 'arguments': '{\"number_to_buy\":2}'}, 'type': 'function'}]}\n", + "{'role': 'tool', 'tool_call_id': 'call_FYLEpX5CVo2dqSyNupcgtFak', 'name': 'orderUmbrella', 'content': 'Order placed.'}\n", + "Second Response: ChatCompletionMessage(content=\"I've placed the order for 2 umbrellas for you. Is there anything else I can help with?\", role='assistant', function_call=None, tool_calls=None)\n" ] } ], "source": [ - "system_prompt = \"You are a helpful assistant.\"\n", - "functions = [\n", - " {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"getCurrentWeather\",\n", - " \"description\": \"Get the weather in location\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. 
San Francisco, CA\"},\n", - " \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n", - " },\n", - " \"required\": [\"location\"]\n", - " }\n", - " }\n", - " },\n", - " { \"type\": \"function\",\n", - " \"function\":\n", - " {\n", - " \"name\": \"orderUmbrella\",\n", - " \"description\": \"Do this to help user to order an umbrella online\", \n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"brand_name\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The name of the umbrella brand\"\n", - " }\n", - " },\n", - " \"required\": [\n", - " \"brand_name\"\n", - " ]\n", - " }\n", - " }},\n", - "]\n", - "\n", - "\n", - "user_query = \"check the weather in boston\"\n", - "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", - "# ,\n", - "# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n", - "# {\"role\": \"user\", \"content\": \"yes pls\"},\n", - "# ]\n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", - "\n", - "res = get_oai_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n", - "print(res)" + "# user_query = \"What is the distance between San Francisco and Cupertino by car and by air\"\n", + "# user_query = \"weather in boston as well as cupertino?\"\n", + "user_query = \"order 2 umbrellas\"\n", + "run_completion(get_oai_response, user_query)" ] }, { @@ -109,17 +150,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 72, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id=0, function=[Function(arguments='{\"location\":\"Boston\",\"unit\":\"c\"}', name='getCurrentWeather')], type='function')]))\n" - ] - } - ], + "outputs": [], "source": [ "import openai\n", "\n", @@ -135,62 +168,40 @@ " messages=msgs,\n", " tools=functions,\n", " tool_choice=\"auto\",\n", - " # functions=functions,\n", - " # function_call=\"auto\",\n", " stream=False,\n", " )\n", " return completion.choices[0]\n", " except Exception as e:\n", " print(e, model, prompt)\n", - "\n", - "system_prompt = \"You are a helpful assistant.\"\n", - "functions = [\n", - " {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"getCurrentWeather\",\n", - " \"description\": \"Get the weather in location\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. 
San Francisco, CA\"},\n", - " \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n", - " },\n", - " \"required\": [\"location\"]\n", - " }\n", - " }\n", - " },\n", - " { \"type\": \"function\",\n", - " \"function\":\n", - " {\n", - " \"name\": \"orderUmbrella\",\n", - " \"description\": \"Do this to help user to order an umbrella online\", \n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"brand_name\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The name of the umbrella brand\"\n", - " }\n", - " },\n", - " \"required\": [\n", - " \"brand_name\"\n", - " ]\n", - " }\n", - " }},\n", - "]\n", - "\n", - "\n", - "user_query = \"check the weather in boston\"\n", - "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", - "# ,\n", - "# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n", - "# {\"role\": \"user\", \"content\": \"yes pls\"},\n", - "# ]\n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", - "\n", - "res = get_mistral_rubra_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n", - "print(res)" + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "First Response:\n", + "Tool Call: 0, Function(arguments='{\"number_to_buy\":\"2\"}', name='orderUmbrella')\n", + "Print before second response...\n", + "{'role': 'system', 'content': 'You are a helpful assistant.'}\n", + "{'role': 'user', 'content': 'order 2 umbrellas'}\n", + "{'role': 'assistant', 'tool_calls': [{'id': '0', 'function': {'name': 'orderUmbrella', 'arguments': '{\"number_to_buy\":\"2\"}'}, 'type': 'function'}]}\n", + "{'role': 'tool', 'tool_call_id': '0', 'name': 'orderUmbrella', 'content': 'Order placed.'}\n", + "Second Response: ChatCompletionMessage(content=' Your order for 2 umbrellas has been placed.', role='assistant', function_call=None, tool_calls=None)\n" + ] + } + ], + "source": [ + "# user_query = \"generate a password of length 10 and another of length 20\" \n", + "# user_query = \"what's the weather in Boston and Cupertino?\"\n", + "user_query = \"order 2 umbrellas\"\n", + "run_completion(get_mistral_rubra_response, user_query)" ] }, { @@ -202,15 +213,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " <>[get_stock_fundermentals(symbol=\"TSLA\")]\n", - "<>[get_stock_fundermentals(symbol=\"GOOG\")]\n" + "None\n" ] } ], @@ -268,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -286,8 +296,8 @@ "traceback": [ "Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n", "\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.12/envs/py310/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3548\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n", - "\u001b[0m Cell \u001b[1;32mIn[19], line 40\u001b[0m\n result_dict = parse_function_call(function_call.strip())\u001b[0m\n", - "\u001b[0m Cell 
@@ -344,7 +354,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -389,7 +399,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -455,26 +465,6 @@
     "\n",
     "print(functions)\n"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "46"
-      ]
-     },
-     "execution_count": 14,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "\n"
-   ]
   }
  ],
 "metadata": {