merge common_tool_calls into common_chat_msg
This commit is contained in:
parent 01b345be0f
commit 82b6e9a5c3
6 changed files with 33 additions and 32 deletions
@@ -604,10 +604,17 @@ std::string common_detokenize(
 // Chat template utils
 //

+struct common_tool_call {
+    std::string name;
+    std::string arguments;
+    std::string id;
+};
+
 // same with llama_chat_message, but uses std::string
 struct common_chat_msg {
     std::string role;
     std::string content;
+    std::vector<common_tool_call> tool_calls;
 };

 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
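With the two structs merged, a parsed assistant reply is carried by a single common_chat_msg. A minimal usage sketch (hypothetical, not part of this commit; the values are made up and the struct definitions are the ones shown above):

    // Hypothetical: one message holds the role, free-form content, and any parsed tool calls.
    common_chat_msg msg;
    msg.role    = "assistant";
    msg.content = "Let me check the weather.";
    msg.tool_calls.push_back({
        /* .name      = */ "get_weather",
        /* .arguments = */ "{\"location\": \"Paris\"}",
        /* .id        = */ "call_0",
    });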
@@ -1015,4 +1015,3 @@ std::string build_grammar(const std::function<void(const llama_grammar_builder &
     converter.check_errors();
     return converter.format_grammar();
 }
-
@@ -150,10 +150,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
+static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
     std::smatch match;

-    common_tool_calls result;
+    common_chat_msg result;
+    result.role = "assistant";
     auto end = input.end();
     auto it = input.begin();
@@ -202,7 +203,7 @@ static common_tool_calls parse_json_tool_calls(const json & tools, const std::st
     return result;
 }

-static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
+static common_chat_msg parse_hermes_tool_calls(const std::string& input) {
     try {
         std::regex start_pattern(R"([\n\s]*<tool_call>)");
         std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -212,10 +213,11 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
         std::sregex_iterator rend;
         std::sregex_iterator rit(input.begin(), end, start_pattern);
         if (rit == rend) {
-            return {input, {}};
+            return {"assistant", input, {}};
         }

-        common_tool_calls result;
+        common_chat_msg result;
+        result.role = "assistant";
         result.content = rit->prefix();

         auto it = rit->suffix().first;
@@ -242,16 +244,17 @@ static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
         }
         return result;
     } catch (const std::exception & e) {
-        return {input, {}};
+        return {"assistant", input, {}};
     }
 }

-static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
+static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
     if (allow_python_tag) {
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
         if (std::regex_search(input, match, python_tag_regex)) {
             return {
+                /* .role = */ "assistant",
                 /* .content = */ match.prefix().str(),
                 /* .tool_calls = */ {
                     {
@@ -268,12 +271,13 @@ static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std:
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }

-static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
+static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
     // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
     static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
     std::smatch match;
     if (std::regex_search(input, match, python_tag_regex)) {
         return {
+            /* .role = */ "assistant",
             /* .content = */ match.prefix().str(),
             /* .tool_calls = */ {
                 {
@@ -289,15 +293,16 @@ static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json &
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
 }

-static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
+static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
     static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
     static std::regex close_regex(R"($|(?=>>>))");
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }

-static common_tool_calls parse_generic_tool_calls(const std::string& input) {
+static common_chat_msg parse_generic_tool_calls(const std::string& input) {
     json data = json::parse(input);
-    common_tool_calls result;
+    common_chat_msg result;
+    result.role = "assistant";
     if (data.contains("tool_calls")) {
         for (const auto & tool_call : data["tool_calls"]) {
             result.tool_calls.push_back({
@@ -319,11 +324,12 @@ static common_tool_calls parse_generic_tool_calls(const std::string& input) {
     return result;
 }

-static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
+static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
     auto content_end = input.find(prefix);
     size_t tc_start = std::string::npos;

-    common_tool_calls result;
+    common_chat_msg result;
+    result.role = "assistant";
     const auto process_tool_calls = [&](const json & tool_calls) {
         for (const auto & tool_call : tool_calls) {
             const auto & arguments = tool_call["arguments"];
@@ -345,19 +351,19 @@ static common_tool_calls parse_prefixed_json_tool_call_array(const std::string&
     return result;
 }

-static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
+static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) {
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }

-static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
+static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) {
     return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
 }

-common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
+common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
     fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
     switch (style) {
         case COMMON_TOOL_CALL_STYLE_NONE:
-            return {input, {}};
+            return {"assistant", input, {}};
         case COMMON_TOOL_CALL_STYLE_GENERIC:
             return parse_generic_tool_calls(input);
         case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
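All of the parsers above now return a common_chat_msg rather than the old common_tool_calls pair, so a caller reads role, content, and tool calls from one value. A rough caller sketch (hypothetical; style, tools, and generated_text are placeholders, not identifiers from this diff):

    // Hypothetical caller: parse a model reply and walk the unified result.
    common_chat_msg msg = parse_tool_calls(style, tools, generated_text);
    fprintf(stderr, "role=%s content=%s\n", msg.role.c_str(), msg.content.c_str());
    for (const auto & tc : msg.tool_calls) {
        fprintf(stderr, "tool call %s -> %s(%s)\n", tc.id.c_str(), tc.name.c_str(), tc.arguments.c_str());
    }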
@@ -21,17 +21,6 @@ enum common_tool_call_style {
     COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
 };

-struct common_tool_call {
-    std::string name;
-    std::string arguments;
-    std::string id;
-};
-
-struct common_tool_calls {
-    std::string content;
-    std::vector<common_tool_call> tool_calls;
-};
-
 struct common_tool_call_handler {
     std::string prompt;
     std::string grammar;
@@ -43,7 +32,7 @@ std::string common_tool_call_style_name(common_tool_call_style style);

 common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);

-common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
+common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

 common_tool_call_handler common_tool_call_handler_init(
     common_tool_call_style style,
@@ -687,7 +687,7 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        common_tool_calls parsed_tool_calls;
+        common_chat_msg parsed_tool_calls;
         json tool_calls;
         json message_content;
         if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
@@ -377,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec
             throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }

-        chat.push_back({role, content});
+        chat.push_back({role, content, /* tool_calls= */ {}});
     }

     const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
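A note on the push_back change above (my reading, not stated in the commit): common_chat_msg is an aggregate, so the old two-element initializer would still compile and leave tool_calls empty; the explicit /* tool_calls= */ {} mainly documents the new member at the call site. A self-contained check under that assumption:

    #include <string>
    #include <vector>

    struct common_tool_call { std::string name, arguments, id; };
    struct common_chat_msg  { std::string role, content; std::vector<common_tool_call> tool_calls; };

    int main() {
        // Both forms value-initialize tool_calls to an empty vector.
        common_chat_msg a = {"user", "hi"};
        common_chat_msg b = {"user", "hi", /* tool_calls= */ {}};
        return (int) (a.tool_calls.size() + b.tool_calls.size()); // 0
    }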