merge common_tool_calls into common_chat_msg

This commit is contained in:
Olivier Chafik 2025-01-22 11:05:05 +00:00
parent 01b345be0f
commit 82b6e9a5c3
6 changed files with 33 additions and 32 deletions

View file

@ -604,10 +604,17 @@ std::string common_detokenize(
// Chat template utils // Chat template utils
// //
// A single tool (function) call extracted from model output.
struct common_tool_call {
// Name of the function/tool being called.
std::string name;
// Call arguments, kept as a string.
// NOTE(review): presumably serialized JSON - confirm against the parsers that fill it.
std::string arguments;
// Identifier of this call.
// NOTE(review): how/when this is populated is not visible here - confirm.
std::string id;
};
// same with llama_chat_message, but uses std::string // same with llama_chat_message, but uses std::string
// Members are listed in aggregate-initialization order: role, content, tool_calls.
struct common_chat_msg { struct common_chat_msg {
std::string role; std::string role;
std::string content; std::string content;
// Tool calls carried by this message; left empty for messages without tool calls.
std::vector<common_tool_call> tool_calls;
}; };
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

View file

@ -1015,4 +1015,3 @@ std::string build_grammar(const std::function<void(const llama_grammar_builder &
converter.check_errors(); converter.check_errors();
return converter.format_grammar(); return converter.format_grammar();
} }

View file

@ -150,10 +150,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content. * Aggregates the prefix, suffix and in-between text into the content.
*/ */
static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
std::smatch match; std::smatch match;
common_tool_calls result; common_chat_msg result;
result.role = "assistant";
auto end = input.end(); auto end = input.end();
auto it = input.begin(); auto it = input.begin();
@ -202,7 +203,7 @@ static common_tool_calls parse_json_tool_calls(const json & tools, const std::st
return result; return result;
} }
static common_tool_calls parse_hermes_tool_calls(const std::string& input) { static common_chat_msg parse_hermes_tool_calls(const std::string& input) {
try { try {
std::regex start_pattern(R"([\n\s]*<tool_call>)"); std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)"); std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@ -212,10 +213,11 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
std::sregex_iterator rend; std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern); std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) { if (rit == rend) {
return {input, {}}; return {"assistant", input, {}};
} }
common_tool_calls result; common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix(); result.content = rit->prefix();
auto it = rit->suffix().first; auto it = rit->suffix().first;
@ -242,16 +244,17 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
} }
return result; return result;
} catch (const std::exception & e) { } catch (const std::exception & e) {
return {input, {}}; return {"assistant", input, {}};
} }
} }
static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
if (allow_python_tag) { if (allow_python_tag) {
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
return { return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(), /* .content = */ match.prefix().str(),
/* .tool_calls = */ { /* .tool_calls = */ {
{ {
@ -268,12 +271,13 @@ static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std:
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
} }
static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool. // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
return { return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(), /* .content = */ match.prefix().str(),
/* .tool_calls = */ { /* .tool_calls = */ {
{ {
@ -289,15 +293,16 @@ static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json &
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
} }
// Parses tool calls written as `>>>name\n{json}` blocks by delegating to parse_json_tool_calls.
// - function_regex: optional `>>>` marker, then the function name (capture group 1), then a newline.
// - close_regex: a call ends at `$` or right before the next `>>>` marker (lookahead, so the marker is not consumed).
// Called with check_names = true.
static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))"); static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
} }
static common_tool_calls parse_generic_tool_calls(const std::string& input) { static common_chat_msg parse_generic_tool_calls(const std::string& input) {
json data = json::parse(input); json data = json::parse(input);
common_tool_calls result; common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) { if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) { for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({ result.tool_calls.push_back({
@ -319,11 +324,12 @@ static common_tool_calls parse_generic_tool_calls(const std::string& input) {
return result; return result;
} }
static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
auto content_end = input.find(prefix); auto content_end = input.find(prefix);
size_t tc_start = std::string::npos; size_t tc_start = std::string::npos;
common_tool_calls result; common_chat_msg result;
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) { const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) { for (const auto & tool_call : tool_calls) {
const auto & arguments = tool_call["arguments"]; const auto & arguments = tool_call["arguments"];
@ -345,19 +351,19 @@ static common_tool_calls parse_prefixed_json_tool_call_array(const std::string&
return result; return result;
} }
// Thin wrapper: parses `input` with parse_prefixed_json_tool_call_array using the "[TOOL_CALLS]" prefix
// (rstrip_prefix left at its default of 0).
static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) { static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
} }
// Thin wrapper: parses `input` with parse_prefixed_json_tool_call_array using the " functools[" prefix.
// NOTE(review): rstrip_prefix = 1 presumably keeps the prefix's trailing `[` as the start of the JSON array - confirm.
static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) { static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
} }
common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str()); fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
switch (style) { switch (style) {
case COMMON_TOOL_CALL_STYLE_NONE: case COMMON_TOOL_CALL_STYLE_NONE:
return {input, {}}; return {"assistant", input, {}};
case COMMON_TOOL_CALL_STYLE_GENERIC: case COMMON_TOOL_CALL_STYLE_GENERIC:
return parse_generic_tool_calls(input); return parse_generic_tool_calls(input);
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1: case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:

View file

@ -21,17 +21,6 @@ enum common_tool_call_style {
COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
}; };
// A single tool (function) call extracted from model output.
struct common_tool_call {
// Name of the function/tool being called.
std::string name;
// Call arguments, kept as a string.
// NOTE(review): presumably serialized JSON - confirm against the parsers that fill it.
std::string arguments;
// Identifier of this call.
// NOTE(review): how/when this is populated is not visible here - confirm.
std::string id;
};
// Result of parsing model output: the plain-text content plus any tool calls found in it.
struct common_tool_calls {
// Text that is not part of a tool call (the whole input when no tool call is found).
std::string content;
// Tool calls found in the input; empty when there are none.
std::vector<common_tool_call> tool_calls;
};
struct common_tool_call_handler { struct common_tool_call_handler {
std::string prompt; std::string prompt;
std::string grammar; std::string grammar;
@ -43,7 +32,7 @@ std::string common_tool_call_style_name(common_tool_call_style style);
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template); common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);
// Parses the model output `input` according to `style` and returns its text content plus any tool calls found.
// `tools` is forwarded to the style-specific parsers that take it.
common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
common_tool_call_handler common_tool_call_handler_init( common_tool_call_handler common_tool_call_handler_init(
common_tool_call_style style, common_tool_call_style style,

View file

@ -687,7 +687,7 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop"; finish_reason = "stop";
} }
common_tool_calls parsed_tool_calls; common_chat_msg parsed_tool_calls;
json tool_calls; json tool_calls;
json message_content; json message_content;
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) { if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {

View file

@ -377,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
} }
chat.push_back({role, content}); chat.push_back({role, content, /* tool_calls= */ {}});
} }
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);