Edit oai.hpp to accept function-calling requests in the OpenAI format.

This commit is contained in:
Yingbei 2024-03-05 13:47:31 -08:00
parent 306d34be7a
commit 9a8762532e
No known key found for this signature in database
GPG key ID: 01CC633FE90B97CD

View file

@ -335,6 +335,108 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
// OAI utils
//
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
std::string final_str = "You have access to the following tools:\n";
std::vector<std::string> function_definitions;
for (const auto & function : functions) {
const auto &spec = function["function"];
const std::string func_name = spec.value("name", "");
const std::string description = spec.value("description", "");
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
std::vector<std::string> func_args;
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
const std::string param = it.key();
const json& details = it.value();
std::string type_annotation = details["type"].get<std::string>();
if (details.contains("enum")) {
type_annotation = "str";
}
std::string arg_str = param + ": " + type_annotation;
if (find(required_params.begin(), required_params.end(), param) == required_params.end()) {
arg_str += " = None";
}
func_args.push_back(arg_str);
}
std::string func_args_str;
for (const auto& arg : func_args) {
if (!func_args_str.empty()) func_args_str += ", ";
func_args_str += arg;
}
// Generating Python-like docstring
std::string docstring = " \"\"\"\n " + description + "\n";
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
const std::string param = it.key();
const json& details = it.value();
const std::string& required_text = find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional";
docstring += " :param " + param + ": " + (details.contains("description") ? details.at("description").get<std::string>() : "No description provided.") + " (" + required_text + ")\n";
docstring += " :type " + param + ": " + details.at("type").get<std::string>() + "\n";
}
docstring += " \"\"\"\n";
// Keeping the function definition in Python format
std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + " pass\n";
function_definitions.push_back(function_definition);
}
for (const auto& def : function_definitions) {
final_str += def + "\n";
}
final_str += "Use the following format if using a tool:\n[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]";
return final_str;
}
std::string default_tool_formatter(const std::vector<json>& tools) {
std::string toolText = "";
std::vector<std::string> toolNames;
for (const auto& tool : tools) {
json function = tool["function"];
std::string name = function["name"];
std::string description = function["description"];
json parameters = function["parameters"]["properties"];
toolText += "> Tool Name: " + name + "\nTool Description: " + description + "\nTool Args:\n";
for (auto& [key, value] : parameters.items()) {
std::string paramType = value["type"];
std::string paramDesc = value.value("description", "");
bool required = function["parameters"]["required"].contains(key);
std::string enumValues = "";
if (value.contains("enum")) {
enumValues = ", should be one of [";
for (const auto& enumValue : value["enum"]) {
enumValues += enumValue.get<std::string>() + ", ";
}
enumValues.pop_back(); // Remove last comma
enumValues.pop_back(); // Remove last space
enumValues += "]";
}
toolText += " - " + key + " (" + paramType + (required ? ", required" : "") + "): " + paramDesc + enumValues + "\n";
}
toolNames.push_back(name);
}
std::string toolNamesString = "";
for (const auto& toolName : toolNames) {
if (!toolNamesString.empty()) {
toolNamesString += ", ";
}
toolNamesString += toolName;
}
std::string formattedPrompt = "You have access to the following tools:\n" + toolText +
"Use the following format if using a tool:\n" +
"Action: tool name (one of [" + toolNamesString + "]).\n" +
"Action Input: {'arg1':'value1', 'arg2':'value2', ...}\n";
return formattedPrompt;
}
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
@ -343,6 +445,39 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;
std::string function_str = "";
if (body.contains("tool") && !body["tool"].empty()) {
// function_str = default_tool_formatter(body["tool"]);
function_str = rubra_format_function_call_str(body["tool"]);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {
// function_str = default_tool_formatter(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"]);
}
if (function_str != "") {
const std::vector<json> expand_messages = [&]() {
std::vector<json> temp_vec = body["messages"];
if (body["messages"][0]["role"] == "system") {
std::string old_content = temp_vec[0]["content"];
temp_vec[0]["content"] = old_content + "\n" + function_str;
}
else {
json function_call;
function_call["role"] = "system";
function_call["content"] = "You are a helpful assistant.\n" + function_str;
temp_vec.push_back(function_call);
}
return temp_vec;
}();
llama_params["prompt"] = format_chat(model, chat_template, expand_messages);
}
else {
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
}
// Map OpenAI parameters to llama.cpp parameters
//
// For parameters that are defined by the OpenAI documentation (e.g.
@ -352,7 +487,6 @@ static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);