diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 48aeef4eb..b7d270b33 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -335,6 +335,108 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector & functions) { + std::string final_str = "You have access to the following tools:\n"; + std::vector function_definitions; + for (const auto & function : functions) { + const auto &spec = function["function"]; + const std::string func_name = spec.value("name", ""); + const std::string description = spec.value("description", ""); + const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({}); + const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector()) : std::vector(); + + std::vector func_args; + for (auto it = parameters.begin(); it != parameters.end(); ++it) { + const std::string param = it.key(); + const json& details = it.value(); + std::string type_annotation = details["type"].get(); + if (details.contains("enum")) { + type_annotation = "str"; + } + std::string arg_str = param + ": " + type_annotation; + if (find(required_params.begin(), required_params.end(), param) == required_params.end()) { + arg_str += " = None"; + } + func_args.push_back(arg_str); + } + std::string func_args_str; + for (const auto& arg : func_args) { + if (!func_args_str.empty()) func_args_str += ", "; + func_args_str += arg; + } + + // Generating Python-like docstring + std::string docstring = " \"\"\"\n " + description + "\n"; + for (auto it = parameters.begin(); it != parameters.end(); ++it) { + const std::string param = it.key(); + const json& details = it.value(); + const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional"; + docstring += " :param " + param + ": " + (details.contains("description") ? details.at("description").get() : "No description provided.") + " (" + required_text + ")\n"; + docstring += " :type " + param + ": " + details.at("type").get() + "\n"; + } + docstring += " \"\"\"\n"; + + // Keeping the function definition in Python format + std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + " pass\n"; + function_definitions.push_back(function_definition); + } + + for (const auto& def : function_definitions) { + final_str += def + "\n"; + } + final_str += "Use the following format if using a tool:\n[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]"; + return final_str; +} + +std::string default_tool_formatter(const std::vector& tools) { + std::string toolText = ""; + std::vector toolNames; + for (const auto& tool : tools) { + json function = tool["function"]; + std::string name = function["name"]; + std::string description = function["description"]; + json parameters = function["parameters"]["properties"]; + + toolText += "> Tool Name: " + name + "\nTool Description: " + description + "\nTool Args:\n"; + for (auto& [key, value] : parameters.items()) { + std::string paramType = value["type"]; + std::string paramDesc = value.value("description", ""); + bool required = function["parameters"]["required"].contains(key); + std::string enumValues = ""; + + if (value.contains("enum")) { + enumValues = ", should be one of ["; + for (const auto& enumValue : value["enum"]) { + enumValues += enumValue.get() + ", "; + } + enumValues.pop_back(); // Remove last comma + enumValues.pop_back(); // Remove last space + enumValues += "]"; + } + + toolText += " - " + key + " (" + paramType + (required ? ", required" : "") + "): " + paramDesc + enumValues + "\n"; + } + + toolNames.push_back(name); + } + + std::string toolNamesString = ""; + for (const auto& toolName : toolNames) { + if (!toolNamesString.empty()) { + toolNamesString += ", "; + } + toolNamesString += toolName; + } + + std::string formattedPrompt = "You have access to the following tools:\n" + toolText + + "Use the following format if using a tool:\n" + + "Action: tool name (one of [" + toolNamesString + "]).\n" + + "Action Input: {'arg1':'value1', 'arg2':'value2', ...}\n"; + return formattedPrompt; +} + + static json oaicompat_completion_params_parse( const struct llama_model * model, const json & body, /* openai api json semantics */ @@ -343,6 +445,39 @@ static json oaicompat_completion_params_parse( llama_params["__oaicompat"] = true; + std::string function_str = ""; + + if (body.contains("tool") && !body["tool"].empty()) { + // function_str = default_tool_formatter(body["tool"]); + function_str = rubra_format_function_call_str(body["tool"]); + } + // If 'tool' is not set or empty, check 'functions' + else if (body.contains("functions") && !body["functions"].empty()) { + // function_str = default_tool_formatter(body["functions"]); + function_str = rubra_format_function_call_str(body["functions"]); + } + + if (function_str != "") { + const std::vector expand_messages = [&]() { + std::vector temp_vec = body["messages"]; + if (body["messages"][0]["role"] == "system") { + std::string old_content = temp_vec[0]["content"]; + temp_vec[0]["content"] = old_content + "\n" + function_str; + } + else { + json function_call; + function_call["role"] = "system"; + function_call["content"] = "You are a helpful assistant.\n" + function_str; + temp_vec.push_back(function_call); + } + return temp_vec; + }(); + llama_params["prompt"] = format_chat(model, chat_template, expand_messages); + } + else { + llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); + } + // Map OpenAI parameters to llama.cpp parameters // // For parameters that are defined by the OpenAI documentation (e.g. @@ -352,7 +487,6 @@ static json oaicompat_completion_params_parse( // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);