diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index f78e169ac..ff905ee0b 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -541,31 +541,35 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector first_tool_rules; - std::vector subsequent_tool_rules; - foreach_function(params.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); - data.grammar_triggers.push_back({name, /* .at_start = */ true}); - data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); - }); - auto first_rule = first_tool_rules.empty() ? 
"" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; - if (params.parallel_tool_calls) { - auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); - } else { - builder.add_rule("root", first_rule); - } + if (!params.tools.is_null() && !params.tools.empty()) { + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector first_tool_rules; + std::vector subsequent_tool_rules; + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); + }); + auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + if (params.parallel_tool_calls) { + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; + builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); + } else { + builder.add_rule("root", first_rule); + } - }, grammar_options); + }, grammar_options); + data.format = "functionary v3.2 tool calls"; + } else { + data.format = "functionary v3.2 content-only"; + } data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt); - data.format = "functionary v3.2 tool calls"; data.parser = [params](const std::string & input) { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); @@ -763,21 +767,24 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat } common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) { - if (params.tools.is_null() || params.tool_choice == "none") { - return common_chat_init_without_tools(tmpl, params); - } - - if (!params.grammar.empty()) { + auto has_tools = !params.tools.is_null() && params.tool_choice != "none"; /* true only when tools were provided and tool use is not disabled */ + if (has_tools && !params.grammar.empty()) { throw std::runtime_error("Cannot specify grammar with tools"); } const auto & src = tmpl.source(); + if (src.find(">>>all") != std::string::npos) { + // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter what + return common_chat_init_functionary_v3_2_tool_call(tmpl, params); + } + + if (!has_tools) { + return common_chat_init_without_tools(tmpl, params); + } + if (src.find("") != std::string::npos) { return common_chat_init_hermes_2_pro_tool_call(tmpl, params); } - if (src.find(">>>all") != std::string::npos) { - return common_chat_init_functionary_v3_2_tool_call(tmpl, params); - } if (src.find("<|start_header_id|>") != std::string::npos && src.find("(); + } + auto has_tool_calls = message.contains("tool_calls"); + if (has_tool_calls) { + for (const auto & tc : message.at("tool_calls")) { + const auto & arguments = tc.at("function").at("arguments"); + ret.tool_calls.push_back({ + tc.at("function").at("name").get(), + arguments.is_string() ? arguments.get() : arguments.dump(), + tc.contains("id") ? 
tc.at("id").get() : "", + }); + } + } + return ret; +} + template static void assert_equals(const T & expected, const T & actual) { if (expected != actual) { @@ -139,184 +162,13 @@ const auto code_interpreter_tool = json::parse(R"({ const json tools = {special_function_tool, python_tool}; const json llama_3_1_tools = {special_function_tool, code_interpreter_tool}; -// static void test_parsing() { -// json request = { -// {"tools", tools} -// }; +struct delta_data { + std::string delta; + std::string grammar; + common_chat_parser parser; +}; -// const auto fooBarCall = json { -// {"type", "function"}, -// {"function", { -// {"name", "foo"}, -// {"arguments", dump({ -// {"bar", 1} -// })}, -// }} -// }; - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, -// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}", -// "", -// json::array({fooBarCall})); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, -// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}", -// "", -// json::array({fooBarCall})); - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools, -// "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", -// "", -// json::array({fooBarCall})); - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, -// ">>>python\n{\"code\": \"print('Hello, world!')\"}", -// "", -// json {{ -// {"type", "function"}, -// {"function", { -// {"name", "python"}, -// {"arguments", dump({ -// {"code", "print('Hello, world!')"} -// })} -// }} -// }}); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, -// ">>>special_function\n{\"arg1\": 1}\n ", -// "", -// json {{ -// {"type", "function"}, -// {"function", { -// {"name", "special_function"}, -// {"arguments", dump({ -// {"arg1", 1} -// })} -// }} -// }}); - -// 
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, -// "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", -// "Hello, world!", -// json { -// { -// {"type", "function"}, -// {"function", { -// {"name", "foo"}, -// {"arguments", dump({ -// {"arg1", 1} -// })} -// }} -// }, -// { -// {"type", "function"}, -// {"function", { -// {"name", "bar"}, -// {"arguments", dump({ -// {"arg2", 2} -// })} -// }} -// }, -// }); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, -// "{ } ", -// " ", -// json {{ -// {"type", "function"}, -// {"function", { -// {"name", "test"}, -// {"arguments", "{}"} -// }} -// }}); - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "<|python_tag|>this could be anything", -// "", -// json {{ -// {"type", "function"}, -// {"function", { -// {"name", "python"}, -// {"arguments", "this could be anything"}, -// }} -// }}); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "I'm thinking<|python_tag|>", -// "I'm thinking", -// json {{ -// {"type", "function"}, -// {"function", { -// {"name", "python"}, -// {"arguments", ""}, -// }} -// }}); -// auto special_function_call = json { -// {"type", "function"}, -// {"function", { -// {"arguments", dump({{"arg1", 1}})}, -// {"name", "special_function"}, -// }}, -// }; -// auto special_function_call_with_id = json::parse(special_function_call.dump()); -// special_function_call_with_id["id"] = "123456789"; - -// auto no_function_call = json::array(); - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}", -// "", -// json::array({{ -// {"type", "function"}, -// {"function", { -// {"arguments", dump({{"code", "print('Hey')"}})}, -// {"name", "python"}, -// }} -// }})); -// 
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", -// "", -// json::array({special_function_call})); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", -// "", -// json::array({special_function_call})); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", -// "", -// json::array({special_function_call})); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", -// "", -// json::array({special_function_call})); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", -// "", -// json::array({special_function_call})); - -// // No match: function unknown -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// no_function_call); -// // No match: bad indentation -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// no_function_call); -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, -// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", -// no_function_call); - -// 
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools, -// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]", -// "Bleh", -// json::array({special_function_call_with_id})); - -// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools, -// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]", -// "Bleh", -// json::array({special_function_call})); -// } - -static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { +static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { fprintf(stderr, "Template source: %s\n", tmpl.source().c_str()); fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str()); @@ -325,13 +177,17 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c params.messages = json::array(); params.messages.push_back(user_message); params.tools = tools; - std::string prefix = common_chat_init(tmpl, params).prompt; + auto prefix_data = common_chat_init(tmpl, params); params.messages.push_back(delta_message); params.add_generation_prompt = false; - std::string full = common_chat_init(tmpl, params).prompt; + auto full_data = common_chat_init(tmpl, params); + + std::string prefix = prefix_data.prompt; + std::string full = full_data.prompt; // Check full starts with prefix if (full.find(prefix) != 0) { + fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str()); throw std::runtime_error("Full message does not start with prefix"); } @@ -350,27 +206,12 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c break; } } - return delta; + return {delta, full_data.grammar, full_data.parser}; 
} -static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) { +static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) { // auto tool_call_style = common_tool_call_style_detect(tmpl); - common_chat_msg expected_msg { - "assistant", - "", - {}, - }; - auto has_tool_calls = test_message.contains("tool_calls"); - if (has_tool_calls) { - for (const auto & tc : test_message.at("tool_calls")) { - const auto & arguments = tc.at("function").at("arguments"); - expected_msg.tool_calls.push_back({ - tc.at("function").at("name").get(), - arguments.is_string() ? arguments.get() : arguments.dump(), - tc.contains("id") ? tc.at("id").get() : "", - }); - } - } + common_chat_msg expected_msg = msg_from_json(test_message); // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, // get the diff and try and parse it w/ the grammar. 
@@ -385,35 +226,37 @@ static void test_template(const common_chat_template & tmpl, const std::vector().c_str()); - if (has_tool_calls) { - auto grammar = build_grammar(chat_data.grammar); + + auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools); + std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl; + if (!expected_delta.empty()) { + assert_equals(expected_delta, data.delta); + } + + if (!skip_parser_test) { + const auto msg = data.parser(data.delta); + assert_msg_equals(expected_msg, msg); + } + + if (!expected_msg.tool_calls.empty()) { + GGML_ASSERT(!data.grammar.empty()); + } + if (!data.grammar.empty()) { + auto grammar = build_grammar(data.grammar); if (!grammar) { throw std::runtime_error("Failed to build grammar"); } - + // TODO: exercice lazy grammars + triggers here, instead of skipping the test if (!skip_grammar_test) { - auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools); - std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; - - const auto msg = chat_data.parser(full_delta); - assert_msg_equals(expected_msg, msg); - - auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { - {"role", "assistant"}, - {"content", {}}, - {"tool_calls", test_message.at("tool_calls")} - }, tools); - if (!match_string(content_less_delta, grammar.get())) { - throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar); + if (!match_string(data.delta, grammar.get())) { + throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + "\n\nGrammar: " + data.grammar); } } } } } -static void test_grammars() { +static void test_template_output_parsers() { auto text_message = json { {"role", "assistant"}, {"content", "Hello, world!"}, @@ -465,6 +308,7 @@ static void test_grammars() { common_chat_params tools_params = 
no_tools_params; tools_params.tools = json::array(); + tools_params.tools.push_back(special_function_tool); auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) { auto data = common_chat_init(tmpl, params); @@ -477,114 +321,182 @@ static void test_grammars() { assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools); + // Generic tool calls doesn't generate / parse content-only messages symmetrically. + assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( + "{\n" + " \"response\": \"Hello, world!\"\n" + "}")); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}"); } { const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""); std::vector end_tokens { "<|end|>" }; assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools); + assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); + // Generic tool calls doesn't generate / parse content-only messages symmetrically. 
+ assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( + "{\n" + " \"response\": \"Hello, world!\"\n" + "}")); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}"); } { const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); std::vector end_tokens { "" }; assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]", + /* skip_grammar_test= */ true); } { const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|im_end|>" }; assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); - test_template(tmpl, end_tokens, python_tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + ""); + test_template(tmpl, end_tokens, python_tool_call_message, tools, + "\n" + "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" + ""); } { const common_chat_template 
tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); std::vector end_tokens { "<|im_end|>" }; assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + ""); } { const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); std::vector end_tokens { "<|im_end|>" }; assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + ""); } { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - // assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); - // test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools); - test_template(tmpl, end_tokens, python_tool_call_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); + // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, 
code_interpreter_tool_call_message, llama_3_1_tools, + "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); + test_template(tmpl, end_tokens, python_tool_call_message, tools, + "<|python_tag|>python.call(code=\"print('hey')\")"); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools); } - { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); - std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); - } { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); + std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { const common_chat_template 
tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"arg1\": 1}"); } { const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params)); + assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params)); + test_template(tmpl, end_tokens, text_message, tools, + "all\n" + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "special_function\n" + "{\"arg1\": 1}"); } { const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); std::vector end_tokens { "<|eot_id|>" }; assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { const common_chat_template 
tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "", ""); std::vector end_tokens { "<|end▁of▁sentence|>" }; assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); - test_template(tmpl, end_tokens, tool_call_message, tools); + test_template(tmpl, end_tokens, text_message, tools, + "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|>"); } } int main() { // test_parsing(); - test_grammars(); + test_template_output_parsers(); std::cout << "\n[tool-call] All tests passed!" << std::endl; return 0;