edit oai.hpp to accept function calling usage in OpenAI format.
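
For illustration, a hypothetical request body along these lines should now be accepted (the parser looks for a "tool" key first, then "functions", and expects each entry to carry the tools-style "function" wrapper; the tool name and fields below are made up for the example):

    {
      "model": "gpt-3.5-turbo",
      "messages": [
        {"role": "user", "content": "What is the weather in Paris?"}
      ],
      "functions": [
        {
          "type": "function",
          "function": {
            "name": "get_weather",
            "description": "Get the current weather.",
            "parameters": {
              "type": "object",
              "properties": {
                "city": {"type": "string", "description": "The city to query."}
              },
              "required": ["city"]
            }
          }
        }
      ]
    }

The rendered tool definitions are injected into (or added as) a system message before the prompt is formatted.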
This commit is contained in:
parent 306d34be7a
commit 9a8762532e
1 changed file with 135 additions and 1 deletion
@@ -335,6 +335,108 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co

// OAI utils
//

static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
    std::string final_str = "You have access to the following tools:\n";
    std::vector<std::string> function_definitions;
    for (const auto & function : functions) {
        const auto & spec = function["function"];
        const std::string func_name   = spec.value("name", "");
        const std::string description = spec.value("description", "");
        const auto parameters      = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
        const auto required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();

        // Build a Python-style argument list, e.g. "city: string, days: int = None".
        std::vector<std::string> func_args;
        for (auto it = parameters.begin(); it != parameters.end(); ++it) {
            const std::string param = it.key();
            const json & details   = it.value();
            std::string type_annotation = details["type"].get<std::string>();
            if (details.contains("enum")) {
                type_annotation = "str";
            }
            std::string arg_str = param + ": " + type_annotation;
            // Parameters not listed in "required" default to None.
            if (std::find(required_params.begin(), required_params.end(), param) == required_params.end()) { // std::find needs <algorithm>
                arg_str += " = None";
            }
            func_args.push_back(arg_str);
        }
        std::string func_args_str;
        for (const auto & arg : func_args) {
            if (!func_args_str.empty()) func_args_str += ", ";
            func_args_str += arg;
        }

        // Generate a Python-like docstring describing each parameter.
        std::string docstring = "    \"\"\"\n    " + description + "\n";
        for (auto it = parameters.begin(); it != parameters.end(); ++it) {
            const std::string param = it.key();
            const json & details   = it.value();
            const std::string required_text = std::find(required_params.begin(), required_params.end(), param) != required_params.end() ? "Required" : "Optional";
            docstring += "    :param " + param + ": " + (details.contains("description") ? details.at("description").get<std::string>() : "No description provided.") + " (" + required_text + ")\n";
            docstring += "    :type " + param + ": " + details.at("type").get<std::string>() + "\n";
        }
        docstring += "    \"\"\"\n";

        // Keep the function definition in Python format.
        std::string function_definition = "def " + func_name + "(" + func_args_str + ") -> None:\n" + docstring + "    pass\n";
        function_definitions.push_back(function_definition);
    }

    for (const auto & def : function_definitions) {
        final_str += def + "\n";
    }
    final_str += "Use the following format if using a tool:\n[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]";
    return final_str;
}
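
// For illustration, assuming the hypothetical get_weather tool from the commit
// message example above, this formatter would produce a prompt along these lines:
//
//   You have access to the following tools:
//   def get_weather(city: string) -> None:
//       """
//       Get the current weather.
//       :param city: The city to query. (Required)
//       :type city: string
//       """
//       pass
//
//   Use the following format if using a tool:
//   [toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]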

std::string default_tool_formatter(const std::vector<json> & tools) {
    std::string toolText = "";
    std::vector<std::string> toolNames;
    for (const auto & tool : tools) {
        json function = tool["function"];
        std::string name        = function["name"];
        std::string description = function["description"];
        json parameters         = function["parameters"]["properties"];
        // "required" is a JSON array of parameter names, so membership has to be
        // checked with std::find; json::contains() only does key lookups on objects.
        const json required_list = function["parameters"].value("required", json::array());

        toolText += "> Tool Name: " + name + "\nTool Description: " + description + "\nTool Args:\n";
        for (auto & [key, value] : parameters.items()) {
            std::string paramType = value["type"];
            std::string paramDesc = value.value("description", "");
            bool required = std::find(required_list.begin(), required_list.end(), key) != required_list.end();
            std::string enumValues = "";

            if (value.contains("enum")) {
                enumValues = ", should be one of [";
                for (const auto & enumValue : value["enum"]) {
                    enumValues += enumValue.get<std::string>() + ", ";
                }
                enumValues.pop_back(); // remove the trailing space
                enumValues.pop_back(); // remove the trailing comma
                enumValues += "]";
            }

            toolText += " - " + key + " (" + paramType + (required ? ", required" : "") + "): " + paramDesc + enumValues + "\n";
        }

        toolNames.push_back(name);
    }

    std::string toolNamesString = "";
    for (const auto & toolName : toolNames) {
        if (!toolNamesString.empty()) {
            toolNamesString += ", ";
        }
        toolNamesString += toolName;
    }

    std::string formattedPrompt = "You have access to the following tools:\n" + toolText +
                                  "Use the following format if using a tool:\n" +
                                  "Action: tool name (one of [" + toolNamesString + "]).\n" +
                                  "Action Input: {'arg1':'value1', 'arg2':'value2', ...}\n";
    return formattedPrompt;
}
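
// For comparison, the same hypothetical get_weather tool rendered by
// default_tool_formatter (kept here, but commented out at the call sites below)
// would look roughly like:
//
//   You have access to the following tools:
//   > Tool Name: get_weather
//   Tool Description: Get the current weather.
//   Tool Args:
//    - city (string, required): The city to query.
//   Use the following format if using a tool:
//   Action: tool name (one of [get_weather]).
//   Action Input: {'arg1':'value1', 'arg2':'value2', ...}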

static json oaicompat_completion_params_parse(
    const struct llama_model * model,
    const json & body, /* openai api json semantics */

@@ -343,6 +445,39 @@ static json oaicompat_completion_params_parse(

llama_params["__oaicompat"] = true;
|
||||
|
||||
std::string function_str = "";
|
||||
|
||||
if (body.contains("tool") && !body["tool"].empty()) {
|
||||
// function_str = default_tool_formatter(body["tool"]);
|
||||
function_str = rubra_format_function_call_str(body["tool"]);
|
||||
}
|
||||
// If 'tool' is not set or empty, check 'functions'
|
||||
else if (body.contains("functions") && !body["functions"].empty()) {
|
||||
// function_str = default_tool_formatter(body["functions"]);
|
||||
function_str = rubra_format_function_call_str(body["functions"]);
|
||||
}
|
||||
|
||||
if (function_str != "") {
|
||||
const std::vector<json> expand_messages = [&]() {
|
||||
std::vector<json> temp_vec = body["messages"];
|
||||
if (body["messages"][0]["role"] == "system") {
|
||||
std::string old_content = temp_vec[0]["content"];
|
||||
temp_vec[0]["content"] = old_content + "\n" + function_str;
|
||||
}
|
||||
else {
|
||||
json function_call;
|
||||
function_call["role"] = "system";
|
||||
function_call["content"] = "You are a helpful assistant.\n" + function_str;
|
||||
temp_vec.push_back(function_call);
|
||||
}
|
||||
return temp_vec;
|
||||
}();
|
||||
llama_params["prompt"] = format_chat(model, chat_template, expand_messages);
|
||||
}
|
||||
else {
|
||||
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
|
||||
}
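
    // For the hypothetical get_weather request from the commit message example,
    // the conversation carries no system turn, so the lambda above would prepend
    // one, expanding the messages to roughly:
    //
    //   [{"role": "system", "content": "You are a helpful assistant.\n<rendered tool prompt>"},
    //    {"role": "user",   "content": "What is the weather in Paris?"}]
    //
    // before format_chat renders the final prompt with the model's chat template.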

    // Map OpenAI parameters to llama.cpp parameters
    //
    // For parameters that are defined by the OpenAI documentation (e.g.

@@ -352,7 +487,6 @@ static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]        = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]       = format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"]  = json_value(body, "temperature", 0.0);
     llama_params["top_k"]        = json_value(body, "top_k", default_sparams.top_k);