diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 77ae529de..ab0f04b46 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -1,3 +1,13 @@
+/*
+    Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+
+    Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+    e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+
+        cmake -B build && cmake --build build --parallel && \
+            ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+
+*/
 #include "chat.hpp"
 #include "chat-template.hpp"
 #include "llama-grammar.h"
@@ -44,7 +54,7 @@ static void assert_equals(const T & expected, const T & actual) {
 }
 
 static std::string read_file(const std::string &path) {
-    std::cout << "# Reading: " << path << std::endl << std::flush;
+    std::cerr << "# Reading: " << path << std::endl << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -168,13 +178,15 @@ struct delta_data {
     common_chat_parser parser;
 };
 
-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
     common_chat_inputs inputs;
     inputs.parallel_tool_calls = true;
     inputs.messages = json::array();
     inputs.messages.push_back(user_message);
    inputs.tools = tools;
+    inputs.tool_choice = tool_choice;
     auto params_prefix = common_chat_params_init(tmpl, inputs);
 
+    inputs.messages.push_back(delta_message);
     inputs.add_generation_prompt = false;
     auto params_full = common_chat_params_init(tmpl, inputs);
@@ -220,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false) {
         {"content", "Hello, world!"}
     };
 
-    auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
+    auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, "required");
     if (!expected_delta.empty()) {
         assert_equals(expected_delta, data.delta);
     }
@@ -300,12 +312,12 @@ static void test_template_output_parsers() {
             "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<end_of_turn>" };
 
-        assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
-        assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
-            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
-        assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
+        assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
             "{\n"
             "  \"response\": \"Hello, world!\"\n"
             "}"));
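Note on the `init_delta` change above: the test renders the prompt twice, once up to the user turn and once with the assistant delta message appended, and treats the non-shared suffix as the completion the model would have to generate; `tool_choice` is now threaded through so `common_chat_params_init` sees it when building the grammar. A minimal sketch of the prefix-stripping step (the helper name is hypothetical, not part of the file):

```cpp
#include <string>

// Hypothetical helper illustrating how init_delta derives the expected delta:
// everything after the longest common prefix of the two rendered prompts.
static std::string extract_delta(const std::string & prompt_prefix, const std::string & prompt_full) {
    std::string::size_type n = 0;
    while (n < prompt_prefix.size() && n < prompt_full.size() && prompt_prefix[n] == prompt_full[n]) {
        n++; // skip the shared prefix (system + user turns)
    }
    return prompt_full.substr(n); // what the model itself would have to emit
}
```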
@@ -339,7 +351,7 @@ static void test_template_output_parsers() {
             "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "</s>" };
 
-        assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
@@ -351,11 +363,11 @@ static void test_template_output_parsers() {
             "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|im_end|>" };
 
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
-            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
-            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -372,9 +384,9 @@ static void test_template_output_parsers() {
             "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
-            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
 
         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -389,7 +401,7 @@ static void test_template_output_parsers() {
             "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens,
             text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
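For context on the `test_template` calls whose expected-delta arguments fall outside these hunks: the hermes 2 pro family is expected to emit a JSON tool call wrapped in `<tool_call>` tags. An illustrative call, assuming the `special_function`/`arg1` tool used elsewhere in this file (the exact payload in the test may differ):

```cpp
// Illustrative expected delta for the hermes 2 pro format (assumed payload):
test_template(tmpl, end_tokens, tool_call_message, tools,
    "<tool_call>\n"
    "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
    "</tool_call>");
```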
@@ -401,7 +413,7 @@ static void test_template_output_parsers() {
             "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
 
@@ -413,8 +425,8 @@ static void test_template_output_parsers() {
             "models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
-        assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "all\n"
@@ -428,7 +440,7 @@ static void test_template_output_parsers() {
             "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eot_id|>" };
 
-        assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
 
@@ -440,7 +452,7 @@ static void test_template_output_parsers() {
             "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
 
-        assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
 
@@ -452,9 +464,33 @@ static void test_template_output_parsers() {
     }
 }
 
-int main() {
-    test_template_output_parsers();
+int main(int argc, char **argv) {
+#ifndef _WIN32
+    if (argc > 1) {
+        common_chat_inputs inputs;
+        inputs.messages = {{{"role", "user"}, {"content", "Hey"}}};
+        inputs.tools = json::array({special_function_tool});
 
-    std::cout << "\n[tool-call] All tests passed!" << std::endl;
+        std::cout << "| Template | Format |\n";
+        std::cout << "|----------|--------|\n";
+
+        for (int i = 1; i < argc; i++) {
+            std::string path = argv[i];
+            if (path.rfind(".jinja") != path.size() - 6) {
+                std::cerr << "Skipping non-jinja file: " << path << std::endl;
+                continue;
+            }
+            common_chat_template tmpl(read_file(path), "", "");
+            auto parts = string_split(path, "/");
+            auto name = parts[parts.size() - 1];
+            std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
+        }
+    }
+    else
+#endif
+    {
+        test_template_output_parsers();
+        std::cout << "\n[chat] All tests passed!" << std::endl;
+    }
     return 0;
 }
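With the new `main`, the test binary doubles as the Markdown summarizer described in the file header. A sample invocation and the shape of its output (rows are illustrative and depend on the templates passed; the row shown matches the format asserted for the Qwen template above):

```
./build/bin/test-chat models/templates/*.jinja 2>/dev/null
| Template | Format |
|----------|--------|
| Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls |
```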