diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index bd09d0742..aef5fbd22 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -319,32 +319,10 @@ static void test_template_output_parsers() { const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "", ""); std::vector end_tokens { "" }; + assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); - assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); - // Generic tool calls doesn't generate / parse content-only messages symmetrically. - assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( - "{\n" - " \"response\": \"Hello, world!\"\n" - "}")); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools, - "{\n" - " \"tool_calls\": [\n" - " {\n" - " \"name\": \"special_function\",\n" - " \"arguments\": {\n" - " \"arg1\": 1\n" - " },\n" - " \"id\": \"123456789\"\n" - " }\n" - " ]\n" - "}"); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""); - std::vector end_tokens { "<|end|>" }; + assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""), tools_params)); - assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); - assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); // Generic tool calls doesn't generate / parse content-only messages symmetrically. assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( "{\n" @@ -368,16 +346,20 @@ static void test_template_output_parsers() { std::vector end_tokens { "" }; assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message_with_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]", /* skip_grammar_test= */ true); } { - const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); + const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); std::vector end_tokens { "<|im_end|>" }; assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""), tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message, tools, "\n" @@ -388,35 +370,13 @@ static void test_template_output_parsers() { "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" ""); } - { - const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); - std::vector end_tokens { "<|im_end|>" }; - - assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools, - "Hello, world!", /* skip_grammar_test= */ true); - test_template(tmpl, end_tokens, tool_call_message, tools, - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - ""); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); - std::vector end_tokens { "<|im_end|>" }; - - assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools, - "Hello, world!", /* skip_grammar_test= */ true); - test_template(tmpl, end_tokens, tool_call_message, tools, - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - ""); - } { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); + assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""), tools_params)); + // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); @@ -424,44 +384,36 @@ static void test_template_output_parsers() { "<|python_tag|>python.call(code=\"print('hey')\")"); test_template(tmpl, end_tokens, tool_call_message, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); - test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); - std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools, - "Hello, world!", /* skip_grammar_test= */ true); - test_template(tmpl, end_tokens, tool_call_message, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); - } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); - std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools, - "Hello, world!", /* skip_grammar_test= */ true); - test_template(tmpl, end_tokens, tool_call_message, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message, tools, "{\"arg1\": 1}"); } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); + + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); + } { const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params)); - assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params)); + assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "all\n" "Hello, world!", /* skip_grammar_test= */ true); @@ -474,6 +426,7 @@ static void test_template_output_parsers() { std::vector end_tokens { "<|eot_id|>" }; assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message, tools, @@ -484,6 +437,7 @@ static void test_template_output_parsers() { std::vector end_tokens { "<|end▁of▁sentence|>" }; assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message, tools,