diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index f28784ccb..cccc98db8 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -298,34 +298,6 @@ const json tools = {special_function_tool, python_tool}; // json::array({special_function_call})); // } -static void test_format_detection() { - common_chat_params no_tools_params; - no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}}; - - common_chat_params tools_params = no_tools_params; - tools_params.tools = json::array(); - - auto describe = [](const std::string & template_file, const common_chat_params & params) { - const common_chat_template tmpl(read_file(template_file), "", ""); - auto data = common_chat_init(tmpl, params); - return data.format; - }; - - assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params)); - assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params)); - assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params)); - assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params)); - assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params)); - assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params)); - assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params)); - assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params)); - assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params)); - assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params)); - assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params)); - assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params)); - // assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja_, tools_params)); -} - static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { fprintf(stderr, "Template source: %s\n", tmpl.source().c_str()); fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str()); @@ -363,20 +335,23 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c return delta; } -static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { +static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) { // auto tool_call_style = common_tool_call_style_detect(tmpl); common_chat_msg expected_msg { "assistant", "", {}, }; - for (const auto & tc : tool_calling_message.at("tool_calls")) { - const auto & arguments = tc.at("function").at("arguments"); - expected_msg.tool_calls.push_back({ - tc.at("function").at("name").get(), - arguments.is_string() ? arguments.get() : arguments.dump(), - tc.contains("id") ? tc.at("id").get() : "", - }); + auto has_tool_calls = test_message.contains("tool_calls"); + if (has_tool_calls) { + for (const auto & tc : test_message.at("tool_calls")) { + const auto & arguments = tc.at("function").at("arguments"); + expected_msg.tool_calls.push_back({ + tc.at("function").at("name").get(), + arguments.is_string() ? arguments.get() : arguments.dump(), + tc.contains("id") ? tc.at("id").get() : "", + }); + } } // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, @@ -386,36 +361,45 @@ static void test_template(const common_chat_template & tmpl, const std::vector().c_str()); - auto grammar = build_grammar(chat_data.grammar); - if (!grammar) { - throw std::runtime_error("Failed to build grammar"); - } + for (const auto & tool_choice : json({"auto", "required"})) { + common_chat_params params; + params.tool_choice = tool_choice; + params.parallel_tool_calls = true; + params.messages = json {user_message, test_message}; + params.tools = tools; + auto chat_data = common_chat_init(tmpl, params); + // fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get().c_str()); + if (has_tool_calls) { + auto grammar = build_grammar(chat_data.grammar); + if (!grammar) { + throw std::runtime_error("Failed to build grammar"); + } - if (!skip_grammar_test) { - auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools); - std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; + if (!skip_grammar_test) { + auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools); + std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; - const auto msg = chat_data.parser->parse_final(full_delta); - assert_msg_equals(expected_msg, msg); + const auto msg = chat_data.parser->parse_final(full_delta); + assert_msg_equals(expected_msg, msg); - auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { - {"role", "assistant"}, - {"content", {}}, - {"tool_calls", tool_calling_message.at("tool_calls")} - }, tools); - if (!match_string(content_less_delta, grammar.get())) { - throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar); + auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { + {"role", "assistant"}, + {"content", {}}, + {"tool_calls", test_message.at("tool_calls")} + }, tools); + if (!match_string(content_less_delta, grammar.get())) { + throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar); + } + } } } } static void test_grammars() { + auto text_message = json { + {"role", "assistant"}, + {"content", "Hello, world!"}, + }; auto tool_call_message = json { {"role", "assistant"}, {"content", {}}, @@ -444,68 +428,128 @@ static void test_grammars() { }}} }; - { - const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); - test_template(tmpl, { "" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); - // assert_equals(tmpl.requires_object_arguments_, true); - test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); - test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); - test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); - test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); - test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); - test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools); - } + + common_chat_params no_tools_params; + no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}}; + + common_chat_params tools_params = no_tools_params; + tools_params.tools = json::array(); + + auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) { + auto data = common_chat_init(tmpl, params); + return data.format; + }; + { const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "", ""); - test_template(tmpl, { "" }, tool_call_message_with_id, tools); + std::vector end_tokens { "" }; + + assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); + assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools); } { const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""); - test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools); + std::vector end_tokens { "<|end|>" }; + + assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + std::vector end_tokens { "" }; + + assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|im_end|>" }; + + assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, python_tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + std::vector end_tokens { "<|im_end|>" }; + + assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); + std::vector end_tokens { "<|im_end|>" }; + + assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, python_tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); + std::vector end_tokens { "<|eot_id|>" }; + + assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); } { const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "", ""); - test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools); + std::vector end_tokens { "<|end▁of▁sentence|>" }; + + assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, tool_call_message, tools); } } int main() { - test_format_detection(); // test_parsing(); test_grammars();