merge common_tool_calls into common_chat_msg

This commit is contained in:
Olivier Chafik 2025-01-22 11:05:05 +00:00
parent 01b345be0f
commit 82b6e9a5c3
6 changed files with 33 additions and 32 deletions

View file

@ -604,10 +604,17 @@ std::string common_detokenize(
// Chat template utils
//
// One tool/function call emitted by the model: the function's name, its
// arguments as a JSON-encoded string, and an optional call id (empty when
// the chat format provides none — TODO confirm against the parsers below).
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
// Same as llama_chat_message, but uses std::string; additionally carries
// any tool calls parsed out of the assistant's output.
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
};
// Checks whether the template supplied via "--chat-template" is supported; returns true if it is valid

View file

@ -1015,4 +1015,3 @@ std::string build_grammar(const std::function<void(const llama_grammar_builder &
converter.check_errors();
return converter.format_grammar();
}

View file

@ -150,10 +150,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
std::smatch match;
common_tool_calls result;
common_chat_msg result;
result.role = "assistant";
auto end = input.end();
auto it = input.begin();
@ -202,7 +203,7 @@ static common_tool_calls parse_json_tool_calls(const json & tools, const std::st
return result;
}
static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
static common_chat_msg parse_hermes_tool_calls(const std::string& input) {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@ -212,10 +213,11 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {input, {}};
return {"assistant", input, {}};
}
common_tool_calls result;
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first;
@ -242,16 +244,17 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
}
return result;
} catch (const std::exception & e) {
return {input, {}};
return {"assistant", input, {}};
}
}
static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
if (allow_python_tag) {
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
@ -268,12 +271,13 @@ static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std:
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
@ -289,15 +293,16 @@ static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json &
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
}
static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
// Parses Functionary v3-style output: calls look like "function_name\n{json args}",
// each optionally introduced by a ">>>" separator (presumably absent before the
// first call — verify against the chat template). Extraction is delegated to
// parse_json_tool_calls with function-name validation enabled.
static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_tool_calls parse_generic_tool_calls(const std::string& input) {
static common_chat_msg parse_generic_tool_calls(const std::string& input) {
json data = json::parse(input);
common_tool_calls result;
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({
@ -319,11 +324,12 @@ static common_tool_calls parse_generic_tool_calls(const std::string& input) {
return result;
}
static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
auto content_end = input.find(prefix);
size_t tc_start = std::string::npos;
common_tool_calls result;
common_chat_msg result;
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) {
const auto & arguments = tool_call["arguments"];
@ -345,19 +351,19 @@ static common_tool_calls parse_prefixed_json_tool_call_array(const std::string&
return result;
}
static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
// Mistral Nemo emits its tool calls as a JSON array introduced by the literal
// "[TOOL_CALLS]" marker; text before the marker is treated as regular content
// (see parse_prefixed_json_tool_call_array).
static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
// FireFunction v2 introduces its tool-call JSON array with " functools[";
// rstrip_prefix=1 presumably trims one trailing character of the matched
// prefix so the '[' is re-read as the start of the JSON array — confirm
// against parse_prefixed_json_tool_call_array.
static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
switch (style) {
case COMMON_TOOL_CALL_STYLE_NONE:
return {input, {}};
return {"assistant", input, {}};
case COMMON_TOOL_CALL_STYLE_GENERIC:
return parse_generic_tool_calls(input);
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:

View file

@ -21,17 +21,6 @@ enum common_tool_call_style {
COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
};
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
struct common_tool_calls {
std::string content;
std::vector<common_tool_call> tool_calls;
};
struct common_tool_call_handler {
std::string prompt;
std::string grammar;
@ -43,7 +32,7 @@ std::string common_tool_call_style_name(common_tool_call_style style);
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);
common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
common_tool_call_handler common_tool_call_handler_init(
common_tool_call_style style,

View file

@ -687,7 +687,7 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}
common_tool_calls parsed_tool_calls;
common_chat_msg parsed_tool_calls;
json tool_calls;
json message_content;
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {

View file

@ -377,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
chat.push_back({role, content});
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);