diff --git a/common/common.h b/common/common.h index 830a56fa7..96e23689e 100644 --- a/common/common.h +++ b/common/common.h @@ -604,10 +604,17 @@ std::string common_detokenize( // Chat template utils // +struct common_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + // same with llama_chat_message, but uses std::string struct common_chat_msg { std::string role; std::string content; + std::vector<common_tool_call> tool_calls; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dacaa1fc3..4d426b6bd 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1015,4 +1015,3 @@ std::string build_grammar(const std::function<std::string(const common_grammar_builder &)> & cb) { diff --git a/common/tool-call.cpp b/common/tool-call.cpp --- a/common/tool-call.cpp +++ b/common/tool-call.cpp -static common_tool_calls parse_hermes_tool_calls(const std::string& input) { +static common_chat_msg parse_hermes_tool_calls(const std::string& input) { try { std::regex start_pattern(R"([\n\s]*<tool_call>)"); std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)"); @@ -212,10 +213,11 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) { std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { - return {input, {}}; + return {"assistant", input, {}}; } - common_tool_calls result; + common_chat_msg result; + result.role = "assistant"; result.content = rit->prefix(); auto it = rit->suffix().first; @@ -242,16 +244,17 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) { } return result; } catch (const std::exception & e) { - return {input, {}}; + return {"assistant", input, {}}; } } -static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { +static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { if (allow_python_tag) { static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { return { + /* .role = */ "assistant", /* .content = */ match.prefix().str(), /* .tool_calls = 
*/ { { @@ -268,12 +271,13 @@ static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std: return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } -static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { +static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { return { + /* .role = */ "assistant", /* .content = */ match.prefix().str(), /* .tool_calls = */ { { @@ -289,15 +293,16 @@ static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); } -static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { +static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } -static common_tool_calls parse_generic_tool_calls(const std::string& input) { +static common_chat_msg parse_generic_tool_calls(const std::string& input) { json data = json::parse(input); - common_tool_calls result; + common_chat_msg result; + result.role = "assistant"; if (data.contains("tool_calls")) { for (const auto & tool_call : data["tool_calls"]) { result.tool_calls.push_back({ @@ -319,11 +324,12 @@ static common_tool_calls parse_generic_tool_calls(const std::string& input) { return result; } -static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, 
const std::string & prefix, size_t rstrip_prefix = 0) { +static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { auto content_end = input.find(prefix); size_t tc_start = std::string::npos; - common_tool_calls result; + common_chat_msg result; + result.role = "assistant"; const auto process_tool_calls = [&](const json & tool_calls) { for (const auto & tool_call : tool_calls) { const auto & arguments = tool_call["arguments"]; @@ -345,19 +351,19 @@ static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& return result; } -static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) { +static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) { return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } -static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) { +static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) { return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); } -common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { +common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str()); switch (style) { case COMMON_TOOL_CALL_STYLE_NONE: - return {input, {}}; + return {"assistant", input, {}}; case COMMON_TOOL_CALL_STYLE_GENERIC: return parse_generic_tool_calls(input); case COMMON_TOOL_CALL_STYLE_LLAMA_3_1: diff --git a/common/tool-call.h b/common/tool-call.h index 5ca422e21..37b5d9739 100644 --- a/common/tool-call.h +++ b/common/tool-call.h @@ -21,17 +21,6 @@ enum common_tool_call_style { COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, }; -struct common_tool_call { - std::string name; - std::string arguments; - 
std::string id; -}; - -struct common_tool_calls { - std::string content; - std::vector<common_tool_call> tool_calls; -}; - struct common_tool_call_handler { std::string prompt; std::string grammar; @@ -43,7 +32,7 @@ std::string common_tool_call_style_name(common_tool_call_style style); common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template); -common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); +common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); common_tool_call_handler common_tool_call_handler_init( common_tool_call_style style, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 67e960a72..ca0626d99 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -687,7 +687,7 @@ struct server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - common_tool_calls parsed_tool_calls; + common_chat_msg parsed_tool_calls; json tool_calls; json message_content; if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 3591ae0a7..17ba6b940 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -377,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - chat.push_back({role, content}); + chat.push_back({role, content, /* tool_calls= */ {}}); } const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);