From a5b2aa58cf31e42704ea98ef17a25d3f88d9db6e Mon Sep 17 00:00:00 2001 From: Yingbei Date: Wed, 20 Mar 2024 15:52:00 -0700 Subject: [PATCH] return function call in OAI format -- tool_calls field --- examples/server/python-parser.hpp | 6 +- examples/server/utils.hpp | 15 +++- test_llamacpp.ipynb | 114 ++++++++++++++++++++++++++++-- 3 files changed, 123 insertions(+), 12 deletions(-) diff --git a/examples/server/python-parser.hpp b/examples/server/python-parser.hpp index 83d70f219..8efff4955 100644 --- a/examples/server/python-parser.hpp +++ b/examples/server/python-parser.hpp @@ -39,6 +39,7 @@ static void parseFunctionCalls(const TSNode& node, std::vector& calls, con if (strcmp(type, "call") == 0) { json call = { + {"id", calls.size()}, {"name", ""}, {"args", json::array()}, {"kwargs", json::object()} }; @@ -50,7 +51,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector& calls, con // Extract the function name call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode)); - // Loop through the arguments unsigned int numArgs = ts_node_named_child_count(argumentsNode); for (unsigned int i = 0; i < numArgs; ++i) { TSNode argNode = ts_node_named_child(argumentsNode, i); @@ -58,7 +58,6 @@ static void parseFunctionCalls(const TSNode& node, std::vector& calls, con // Check if the argument is a positional argument or a keyword argument if (strcmp(argType, "argument") == 0 || strcmp(argType, "positional_arguments") == 0 || strcmp(argType, "string") == 0 || strcmp(argType, "integer") == 0 || strcmp(argType, "true") == 0 || strcmp(argType, "false") == 0) { - // For simplification, we treat the entire content as the argument std::string value = std::string(source_code + ts_node_start_byte(argNode), ts_node_end_byte(argNode) - ts_node_start_byte(argNode)); call["args"].push_back(parseValue(value)); } else if (strcmp(argType, "keyword_argument") == 0) { @@ -89,6 +88,7 @@ static void
parseFunctionCalls(const TSNode& node, std::vector& calls, con } static std::vector parsePythonFunctionCalls(std::string source_string) { + // Parse Python function calls from the source code and return a JSON array std::vector calls; std::string delimiter = "<>"; std::string source_code; @@ -116,8 +116,6 @@ static std::vector parsePythonFunctionCalls(std::string source_string) { parseFunctionCalls(root_node, calls, source_code_cstr, 0); - // Output the parsed calls - ts_tree_delete(tree); ts_parser_delete(parser); printf("calls: %s\n", json(calls).dump().c_str()); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 73cd6dd44..d3727fc6c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -585,9 +585,22 @@ static json format_final_response_oaicompat(const json & request, json result, c {"message", json{{"content", content}, {"role", "assistant"}}}}}); } else { + std::vector oai_format_tool_calls; + for (size_t i = 0; i < parsed_content.size(); ++i) { + const auto &pc = parsed_content[i]; + // Use 'pc' and 'i' as needed + json tool_call; + tool_call["id"] = pc["id"]; + tool_call["type"] = "function"; + tool_call["function"] = json{{ + {"name" , pc["name"]}, + {"arguments" , pc["kwargs"].dump()}, + }}; + oai_format_tool_calls.push_back(tool_call); + } choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, - {"message", json{{"content", parsed_content}, + {"message", json{{"tool_calls", oai_format_tool_calls}, {"role", "assistant"}}}}}); } } diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb index 4377a68c8..4121a9c84 100644 --- a/test_llamacpp.ipynb +++ b/test_llamacpp.ipynb @@ -1,17 +1,24 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI" + ] + }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import openai\n", "\n", "\n", - "def get_mistral_rubra_response(prompt, model, functions, msgs):\n", 
- " openai.api_key = \"sk-79\"\n", - " openai.base_url = \"http://localhost:8019/v1/\"\n", + "def get_oai_response(prompt, model, functions, msgs):\n", + " openai.api_key = \"sk-\"\n", + " # openai.base_url = \"http://localhost:8019/v1/\"\n", " \n", " try:\n", " completion = openai.chat.completions.create(\n", @@ -31,18 +38,111 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=[{'args': [], 'kwargs': {'location': 'Boston', 'unit': 'c'}, 'name': 'getCurrentWeather'}], role='assistant', function_call=None, tool_calls=None))\n" + "Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_NVPCmdRCtY9lO7sVhOFRPS6R', function=Function(arguments='{\"location\":\"Boston, MA\",\"unit\":\"f\"}', name='getCurrentWeather'), type='function')]))\n" ] } ], "source": [ + "system_prompt = \"You are a helpful assistant.\"\n", + "functions = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"getCurrentWeather\",\n", + " \"description\": \"Get the weather in location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\"type\": \"string\", \"description\": \"The city and state e.g. 
San Francisco, CA\"},\n", + " \"unit\": {\"type\": \"string\", \"enum\": [\"c\", \"f\"]}\n", + " },\n", + " \"required\": [\"location\"]\n", + " }\n", + " }\n", + " },\n", + " { \"type\": \"function\",\n", + " \"function\":\n", + " {\n", + " \"name\": \"orderUmbrella\",\n", + " \"description\": \"Do this to help user to order an umbrella online\", \n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"brand_name\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The name of the umbrella brand\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"brand_name\"\n", + " ]\n", + " }\n", + " }},\n", + "]\n", + "\n", + "\n", + "user_query = \"check the weather in boston\"\n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", + "# ,\n", + "# {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. 
Would you like to order an umbrella?\"},\n", + "# {\"role\": \"user\", \"content\": \"yes pls\"},\n", + "# ]\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", + "\n", + "res = get_oai_response(user_query, \"gpt-4-0125-preview\", functions=functions, msgs=msgs)\n", + "print(res)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Function.cpp" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id=0, function=[Function(arguments='{\"location\":\"Boston\",\"unit\":\"c\"}', name='getCurrentWeather')], type='function')]))\n" + ] + } + ], + "source": [ + "import openai\n", + "\n", + "\n", + "def get_mistral_rubra_response(prompt, model, functions, msgs):\n", + " openai.api_key = \"sk-\"\n", + " openai.base_url = \"http://localhost:8019/v1/\"\n", + " \n", + " try:\n", + " completion = openai.chat.completions.create(\n", + " model=model,\n", + " temperature=0.1,\n", + " messages=msgs,\n", + " tools=functions,\n", + " tool_choice=\"auto\",\n", + " # functions=functions,\n", + " # function_call=\"auto\",\n", + " stream=False,\n", + " )\n", + " return completion.choices[0]\n", + " except Exception as e:\n", + " print(e, model, prompt)\n", + "\n", "system_prompt = \"You are a helpful assistant.\"\n", "functions = [\n", " {\n", @@ -95,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 7, "metadata": {}, "outputs": [ {