Make the C++ formatter an exact one-to-one match with our Python Rubra tool formatter
This commit is contained in:
parent
aeefbbb4a3
commit
48c02498f2
2 changed files with 55 additions and 24 deletions
|
@ -338,9 +338,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
|
|||
|
||||
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
|
||||
std::string final_str = "You have access to the following tools:\n";
|
||||
json type_mapping = {
|
||||
{"string", "str"},
|
||||
{"number", "float"},
|
||||
{"object", "Dict[str, Any]"},
|
||||
{"array", "List"},
|
||||
{"boolean", "bool"},
|
||||
{"null", "None"}
|
||||
};
|
||||
std::vector<std::string> function_definitions;
|
||||
for (const auto & function : functions) {
|
||||
const auto &spec = function["function"];
|
||||
const auto &spec = function.contains("function") ? function["function"] : function;
|
||||
const std::string func_name = spec.value("name", "");
|
||||
const std::string description = spec.value("description", "");
|
||||
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
|
||||
|
@ -350,11 +358,13 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
|
|||
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
|
||||
const std::string param = it.key();
|
||||
const json& details = it.value();
|
||||
std::string type_annotation = details["type"].get<std::string>();
|
||||
std::string json_type = details["type"].get<std::string>();
|
||||
std::string python_type = type_mapping.value(json_type, "Any");
|
||||
// TODO: handle the case array: should provide more details about type, such as List[str]
|
||||
if (details.contains("enum")) {
|
||||
type_annotation = "str";
|
||||
python_type = "str";
|
||||
}
|
||||
std::string arg_str = param + ": " + type_annotation;
|
||||
std::string arg_str = param + ": " + python_type;
|
||||
if (find(required_params.begin(), required_params.end(), param) == required_params.end()) {
|
||||
arg_str += " = None";
|
||||
}
|
||||
|
@ -367,18 +377,40 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
|
|||
}
|
||||
|
||||
// Generating Python-like docstring
|
||||
std::string docstring = " \"\"\"\n " + description + "\n";
|
||||
std::string docstring = " \"\"\"\n " + description + "\n\n";
|
||||
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
|
||||
const std::string param = it.key();
|
||||
const json& details = it.value();
|
||||
const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional";
|
||||
docstring += " :param " + param + ": " + (details.contains("description") ? details.at("description").get<std::string>() : "No description provided.") + " (" + required_text + ")\n";
|
||||
docstring += " :type " + param + ": " + details.at("type").get<std::string>() + "\n";
|
||||
const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "" : "(Optional)";
|
||||
std::string param_description = "";
|
||||
if (details.count("description") > 0) {
|
||||
param_description = details["description"]; // Assuming the description is the first element
|
||||
}
|
||||
if (details.count("enum") > 0) {
|
||||
std::string enum_values;
|
||||
for (const std::string val : details["enum"]) {
|
||||
if (!enum_values.empty()) {
|
||||
enum_values += " or ";
|
||||
}
|
||||
enum_values = enum_values+ "\"" + val + "\"";
|
||||
}
|
||||
if (details["enum"].size() == 1) {
|
||||
param_description += " Only Acceptable value is: " + enum_values;
|
||||
} else {
|
||||
param_description += " Only Acceptable values are: " + enum_values;
|
||||
}
|
||||
}
|
||||
if (param_description.empty()) {
|
||||
param_description = "No description provided.";
|
||||
}
|
||||
docstring += " :param " + param + ": " + param_description + " " + required_text + "\n";
|
||||
std::string param_type = details["type"].get<std::string>();
|
||||
docstring += " :type " + param + ": " + type_mapping.value(param_type, "Any") + "\n";
|
||||
}
|
||||
docstring += " \"\"\"\n";
|
||||
|
||||
// Keeping the function definition in Python format
|
||||
std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + " pass\n";
|
||||
std::string function_definition = "def " + func_name + "(" + func_args_str + "):\n" + docstring;
|
||||
function_definitions.push_back(function_definition);
|
||||
}
|
||||
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 106,
|
||||
"execution_count": 146,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_gorilla_response(prompt, model, functions, msgs):\n",
|
||||
"def get_mistral_rubra_response(prompt, model, functions, msgs):\n",
|
||||
" openai.api_key = \"abc\"\n",
|
||||
" openai.base_url = \"http://localhost:8019/v1/\"\n",
|
||||
" \n",
|
||||
|
@ -29,14 +29,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 136,
|
||||
"execution_count": 161,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" I'm sorry to hear that it's raining in Boston. If you need help ordering an umbrella, I can assist with that. Could you please tell me the brand name you prefer?\n"
|
||||
" The current weather in Boston is 72 degrees Fahrenheit and it's raining.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -79,29 +79,28 @@
|
|||
"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_query = \"What's the weather like in Boston? if it's raining, can you help me order an umbrella?\"\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current weather in Boston is 72 degrees f and rainy.\"}\n",
|
||||
"user_query = \"check the weather in boston\"\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[getCurrentWeather(location=\"Boston)]'}, {\"role\": \"observation\", \"content\": \"<<observation>>72 f, rainy.\"}\n",
|
||||
" ,\n",
|
||||
" # {\"role\": \"assistant\", \"content\": \"The current weather in Boston, MA is 72 degrees Fahrenheit and rainy. It might be a good idea to carry an umbrella with you.\"},\n",
|
||||
" # {\"role\": \"user\", \"content\": \"can you help me order an umbrella?\"},\n",
|
||||
" {\"role\": \"assistant\", \"content\": \"The current weather in Boston is 72 degrees Fahrenheit and it's raining. Would you like to order an umbrella?\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"yes pls\"},\n",
|
||||
" ]\n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}]\n",
|
||||
"\n",
|
||||
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
||||
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
||||
"print(res.message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 134,
|
||||
"execution_count": 160,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" <<functions>>[get_stock_price(symbol=\"TSLA\")]\n",
|
||||
"<<functions>>[get_stock_price(symbol=\"GOOG\")]\n"
|
||||
" The current stock price of Tesla (TSLA) is $170 and the current stock price of Google (GOOG) is $138.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -151,9 +150,9 @@
|
|||
"user_query = \"What's the stock price of Tesla and Google?\"\n",
|
||||
"\n",
|
||||
"# \n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>The current stock price of Tesla is $170, which has decreased 15 percent since 2024, and the current stock price of Google is $138, also not very good.\"}]\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
|
||||
"res = get_gorilla_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
||||
"msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query}, {\"role\": \"function\", \"content\": '<<functions>>[get_stock_price(symbol=\"TSLA\")], <<functions>>[get_stock_price(symbol=\"GOOG\")]'}, {\"role\": \"observation\", \"content\": \"<<observation>>170, 138\"}]\n",
|
||||
"# msgs = [{\"role\": \"system\", \"content\":system_prompt} ,{\"role\": \"user\", \"content\": user_query},]\n",
|
||||
"res = get_mistral_rubra_response(user_query, \"gorilla-openfunctions-v2\", functions=functions, msgs=msgs)\n",
|
||||
"print(res.message.content)"
|
||||
]
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue