Tool call support (generic + native for Llama, Functionary, Hermes, Mistral, Firefunction, DeepSeek) w/ lazy grammars (#9639)

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Olivier Chafik 2025-01-30 19:13:58 +00:00 committed by GitHub
parent 27d135c970
commit 8b576b6c55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3861 additions and 156 deletions

View file

@ -17,6 +17,7 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include <random>
@ -376,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
chat.push_back({role, content});
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
@ -483,14 +484,13 @@ static bool ends_with(const std::string & str, const std::string & suffix) {
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
auto it = std::find(stop.rbegin(), stop.rend(), text.back());
while (it != stop.rend()) {
size_t length = std::distance(it, stop.rend());
if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) {
return text.length() - length;
}
it = std::find(std::next(it), stop.rend(), text.back());
}
}
@ -580,21 +580,30 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
const common_chat_template & tmpl,
bool use_jinja)
bool use_jinja,
const common_chat_templates & chat_templates)
{
json llama_params;
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
? *chat_templates.template_tool_use
: *chat_templates.template_default;
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false);
if (has_tools) {
if (use_jinja) {
LOG_WRN("tools param is not fully supported yet\n");
} else {
if (tools.is_array() && !tools.empty()) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!use_jinja) {
throw std::runtime_error("tools param requires --jinja flag");
}
}
if (!use_jinja) {
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
}
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@ -619,7 +628,38 @@ static json oaicompat_completion_params_parse(
// Apply chat template to the list of messages
if (use_jinja) {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
if (tool_choice != "none" && llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
common_chat_inputs inputs;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json::object());
auto chat_params = common_chat_params_init(tmpl, inputs);
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
}
@ -638,14 +678,6 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tool_choice" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
}
}
// Copy remaining properties to llama_params
// This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp