make changes to make sure it's an exact 1 to 1 mapping to our python rubra tool formatter

This commit is contained in:
Yingbei 2024-03-12 17:59:00 -07:00
parent aeefbbb4a3
commit 48c02498f2
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD
2 changed files with 55 additions and 24 deletions

View file

@ -338,9 +338,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
std::string final_str = "You have access to the following tools:\n";
json type_mapping = {
{"string", "str"},
{"number", "float"},
{"object", "Dict[str, Any]"},
{"array", "List"},
{"boolean", "bool"},
{"null", "None"}
};
std::vector<std::string> function_definitions;
for (const auto & function : functions) {
const auto &spec = function["function"];
const auto &spec = function.contains("function") ? function["function"] : function;
const std::string func_name = spec.value("name", "");
const std::string description = spec.value("description", "");
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
@ -350,11 +358,13 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
const std::string param = it.key();
const json& details = it.value();
std::string type_annotation = details["type"].get<std::string>();
std::string json_type = details["type"].get<std::string>();
std::string python_type = type_mapping.value(json_type, "Any");
// TODO: handle the case array: should provide more details about type, such as List[str]
if (details.contains("enum")) {
type_annotation = "str";
python_type = "str";
}
std::string arg_str = param + ": " + type_annotation;
std::string arg_str = param + ": " + python_type;
if (find(required_params.begin(), required_params.end(), param) == required_params.end()) {
arg_str += " = None";
}
@ -367,18 +377,40 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
}
// Generating Python-like docstring
std::string docstring = " \"\"\"\n " + description + "\n";
std::string docstring = " \"\"\"\n " + description + "\n\n";
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
const std::string param = it.key();
const json& details = it.value();
const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional";
docstring += " :param " + param + ": " + (details.contains("description") ? details.at("description").get<std::string>() : "No description provided.") + " (" + required_text + ")\n";
docstring += " :type " + param + ": " + details.at("type").get<std::string>() + "\n";
const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "" : "(Optional)";
std::string param_description = "";
if (details.count("description") > 0) {
param_description = details["description"]; // Assuming the description is the first element
}
if (details.count("enum") > 0) {
std::string enum_values;
for (const std::string val : details["enum"]) {
if (!enum_values.empty()) {
enum_values += " or ";
}
enum_values = enum_values+ "\"" + val + "\"";
}
if (details["enum"].size() == 1) {
param_description += " Only Acceptable value is: " + enum_values;
} else {
param_description += " Only Acceptable values are: " + enum_values;
}
}
if (param_description.empty()) {
param_description = "No description provided.";
}
docstring += " :param " + param + ": " + param_description + " " + required_text + "\n";
std::string param_type = details["type"].get<std::string>();
docstring += " :type " + param + ": " + type_mapping.value(param_type, "Any") + "\n";
}
docstring += " \"\"\"\n";
// Keeping the function definition in Python format
std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + " pass\n";
std::string function_definition = "def " + func_name + "(" + func_args_str + "):\n" + docstring;
function_definitions.push_back(function_definition);
}

View file

@ -2,14 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"\n",
"def get_gorilla_response(prompt, model, functions, msgs):\n",
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
" openai.api_key = \"abc\"\n",
" openai.base_url = \"http://localhost:8019/v1/\"\n",
" \n",
@ -29,14 +29,14 @@
},
{
"cell_type": "code",
"execution_count": 136,
"execution_count": 161,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. Could you please tell me the brand name you prefer?\n"
" The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n"
]
}
],
@ -79,29 +79,28 @@
"]\n",
"\n",
"\n",
"user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current weather in Boston is 72 degrees f and rainy.\"}\n",
"user_query = \"check the weather in boston\"\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
" ,\n",
" # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. It might be a good idea to carry an umbrella with you.\"},\n",
" # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n",
" {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
" {\"role\": \"user\", \"content\": \"yes pls\"},\n",
" ]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
"\n",
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"print(res.message.content)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"execution_count": 160,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" <<functions>>[get_stock_price(symbol=\"TSLA\")]\n",
"<<functions>>[get_stock_price(symbol=\"GOOG\")]\n"
" The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n"
]
}
],
@ -151,9 +150,9 @@
"user_query = \"What's the stock price of Tesla and Google?\"\n",
"\n",
"# \n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>170, 138\"}]\n",
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
"print(res.message.content)"
]
}