diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index d5f235f21..8ea031bd5 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -288,6 +288,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```"); data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.format = "generic tool calls"; data.parser = std::make_unique([&](const std::string & input) { json data = json::parse(input); common_chat_msg result; @@ -355,7 +356,8 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); } data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); - data.parser = std::make_unique([](const std::string & input) -> common_chat_msg { + data.format = "mistral nemo tool calls"; + data.parser = std::make_unique([](const std::string & input) { return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); }); return data; @@ -397,6 +399,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt, { {"builtin_tools", builtin_tools}, }); + data.format = "llama 3.1 tool calls"; data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); @@ -452,7 +455,8 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {}); - data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.format = "llama 3.2 tool calls"; + data.parser = std::make_unique([params](const std::string & input) { static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); static std::regex close_regex("\\}"); auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); @@ -482,7 +486,8 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt); - data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.format = "deepseek r1 tool calls"; + data.parser = std::make_unique([params](const std::string & input) { static std::regex trigger_regex("<|tool▁calls▁begin|>"); static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); static std::regex close_regex("```<|tool▁call▁end|>"); @@ -524,13 +529,14 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_ data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); } data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); - data.parser = std::make_unique([](const std::string & input) -> common_chat_msg { + data.format = "firefunction v2 tool calls"; + data.parser = std::make_unique([](const std::string & input) { return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); }); return data; } -static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { fprintf(stderr, "[%s]\n", __func__); // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar @@ -562,7 +568,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt); - data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.format = "functionary v3.2 tool calls"; + data.parser = std::make_unique([params](const std::string & input) { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); @@ -629,6 +636,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.format = "functionary v3.1 llama 3.1 tool calls"; data.parser = std::make_unique([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); @@ -682,6 +690,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.format = "hermes 2 pro tool calls"; data.parser = std::make_unique([&](const std::string & input) -> common_chat_msg { try { std::regex start_pattern(R"([\n\s]*)"); @@ -733,6 +742,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt); + data.format = "content-only"; data.parser = std::make_unique(); if (!params.json_schema.is_null()) { if (!params.grammar.empty()) { @@ -759,7 +769,7 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc return common_chat_init_hermes_2_pro_tool_call(tmpl, params); } if (src.find(">>>all") != std::string::npos) { - return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params); + return common_chat_init_functionary_v3_2_tool_call(tmpl, params); } if (src.find("<|start_header_id|>") != std::string::npos && src.find(" grammar_triggers; std::vector additional_stops; std::unique_ptr parser; + std::string format; // For debugging and testing. }; struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params); diff --git a/common/chat-template.hpp b/common/chat-template.hpp index f528875f7..f239e10fd 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -208,7 +208,6 @@ class chat_template { arguments = json::parse(arguments.get()); } catch (const std::exception & ecvt) { fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); - arguments = arguments; } } } diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index 5e9db450e..14c441fe9 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -298,26 +298,33 @@ const json tools = {special_function_tool, python_tool}; // json::array({special_function_call})); // } -// static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) { -// const common_chat_template tmpl(read_file(template_file), "", ""); -// auto tool_call_style = common_tool_call_style_detect(tmpl); -// std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; -// assert_equals(expected, tool_call_style); -// } +static void test_format_detection() { + common_chat_params 
no_tools_params; + no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}}; -// static void test_tool_call_style_detection() { -// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1); -// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3); -// test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2); -// test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1); -// test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2); -// test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); -// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); -// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); -// test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS); -// test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO); -// test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC); -// } + common_chat_params tools_params = no_tools_params; + tools_params.tools = json::array(); + + auto describe = [](const std::string & template_file, const common_chat_params & params) { + const common_chat_template tmpl(read_file(template_file), "", ""); + auto data = common_chat_init(tmpl, params); + return data.format; + }; + + assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), 
describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params)); + assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params)); + assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params)); + assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params)); + assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params)); + assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params)); + assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params)); + assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params)); + assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params)); + // assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", tools_params)); +} static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const 
json & tools) { fprintf(stderr, "Template source: %s\n", tmpl.source().c_str()); @@ -498,7 +505,7 @@ static void test_grammars() { } int main() { - // test_tool_call_style_detection(); + test_format_detection(); // test_parsing(); test_grammars();