diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 1c10063ee..c17dbcd84 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -338,9 +338,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector & functions) { std::string final_str = "You have access to the following tools:\n"; + json type_mapping = { + {"string", "str"}, + {"number", "float"}, + {"object", "Dict[str, Any]"}, + {"array", "List"}, + {"boolean", "bool"}, + {"null", "None"} + }; std::vector function_definitions; for (const auto & function : functions) { - const auto &spec = function["function"]; + const auto &spec = function.contains("function") ? function["function"] : function; const std::string func_name = spec.value("name", ""); const std::string description = spec.value("description", ""); const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({}); @@ -350,11 +358,13 @@ static std::string rubra_format_function_call_str(const std::vector & func for (auto it = parameters.begin(); it != parameters.end(); ++it) { const std::string param = it.key(); const json& details = it.value(); - std::string type_annotation = details["type"].get(); + std::string json_type = details["type"].get(); + std::string python_type = type_mapping.value(json_type, "Any"); + // TODO: handle the case array: should provide more details about type, such as List[str] if (details.contains("enum")) { - type_annotation = "str"; + python_type = "str"; } - std::string arg_str = param + ": " + type_annotation; + std::string arg_str = param + ": " + python_type; if (find(required_params.begin(), required_params.end(), param) == required_params.end()) { arg_str += " = None"; } @@ -367,18 +377,40 @@ static std::string rubra_format_function_call_str(const std::vector & func } // Generating Python-like docstring - std::string docstring = " \"\"\"\n " + description + "\n"; + std::string docstring = " \"\"\"\n " + description + "\n\n"; for (auto it = parameters.begin(); it != parameters.end(); ++it) { const std::string param = it.key(); const json& details = it.value(); - const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional"; - docstring += " :param " + param + ": " + (details.contains("description") ? details.at("description").get() : "No description provided.") + " (" + required_text + ")\n"; - docstring += " :type " + param + ": " + details.at("type").get() + "\n"; + const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "" : "(Optional)"; + std::string param_description = ""; + if (details.count("description") > 0) { + param_description = details["description"]; // Assuming the description is the first element + } + if (details.count("enum") > 0) { + std::string enum_values; + for (const std::string val : details["enum"]) { + if (!enum_values.empty()) { + enum_values += " or "; + } + enum_values = enum_values+ "\"" + val + "\""; + } + if (details["enum"].size() == 1) { + param_description += " Only Acceptable value is: " + enum_values; + } else { + param_description += " Only Acceptable values are: " + enum_values; + } + } + if (param_description.empty()) { + param_description = "No description provided."; + } + docstring += " :param " + param + ": " + param_description + " " + required_text + "\n"; + std::string param_type = details["type"].get(); + docstring += " :type " + param + ": " + type_mapping.value(param_type, "Any") + "\n"; } docstring += " \"\"\"\n"; // Keeping the function definition in Python format - std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + " pass\n"; + std::string function_definition = "def " + func_name + "(" + func_args_str + "):\n" + docstring; function_definitions.push_back(function_definition); } diff --git a/test_llamacpp.ipynb b/test_llamacpp.ipynb index c33b9ea79..15781f283 100644 --- a/test_llamacpp.ipynb +++ b/test_llamacpp.ipynb @@ -2,14 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 106, + "execution_count": 146, "metadata": {}, "outputs": [], "source": [ "import openai\n", "\n", "\n", - "def get_gorilla_response(prompt, model, functions, msgs):\n", + "def get_mistral_rubra_response(prompt, model, functions, msgs):\n", " openai.api_key = \"abc\"\n", " openai.base_url = \"http://localhost:8019/v1/\"\n", " \n", @@ -29,14 +29,14 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 161, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. Could you please tell me the brand name you prefer?\n" + " The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n" ] } ], @@ -79,29 +79,28 @@ "]\n", "\n", "\n", - "user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>The current weather in Boston is 72 degrees f and rainy.\"}\n", + "user_query = \"check the weather in boston\"\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<>72 f, rainy.\"}\n", " ,\n", - " # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. It might be a good idea to carry an umbrella with you.\"},\n", - " # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n", + " {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n", + " {\"role\": \"user\", \"content\": \"yes pls\"},\n", " ]\n", "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n", "\n", - "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", + "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", "print(res.message.content)" ] }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 160, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " <>[get_stock_price(symbol=\"TSLA\")]\n", - "<>[get_stock_price(symbol=\"GOOG\")]\n" + " The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n" ] } ], @@ -151,9 +150,9 @@ "user_query = \"What's the stock price of Tesla and Google?\"\n", "\n", "# \n", - "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n", - "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", - "res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", + "msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<>[get_stock_price(symbol=\"TSLA\")], <>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<>170, 138\"}]\n", + "# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n", + "res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n", "print(res.message.content)" ] }