From 5ec4c5e4d33c928ee5179aa1d1149d93392af543 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 26 Jan 2025 21:38:07 +0000 Subject: [PATCH] reshuffle chat handlers --- common/chat-handler.cpp | 144 +++++++----- common/chat-handler.hpp | 6 +- tests/test-chat-handler.cpp | 425 +++++++++++++++++++----------------- 3 files changed, 318 insertions(+), 257 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 0c0aba5e9..effeeefda 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -76,7 +76,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri if (type == "function") { tool_names.push_back(tool["function"]["name"]); } else if (type == "code_interpreter") { - tool_names.push_back("ipython"); + tool_names.push_back("python"); } } } @@ -171,6 +171,10 @@ public: /* .tool_calls = */ {}, }; } + + std::unique_ptr clone() const override { + return std::make_unique(); + } }; class monolithic_chat_parser : public common_chat_parser { @@ -192,13 +196,48 @@ public: input_buffer_.clear(); return out; } + + std::unique_ptr clone() const override { + return std::make_unique(parse_final_); + } }; -static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +const auto python_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "python", + "description": "an ipython interpreter", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + } + } +})"); + +static void foreach_normalized_tool(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type")) { + continue; + } + if (tool["type"] == "code_interpreter") { + fn(python_tool); + } else { + fn(tool); + } + } +} + +static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { common_chat_data data; auto tool_call_schemas = json::array(); - for (const auto & tool : params.tools) { + foreach_normalized_tool(params.tools, [&](const json & tool) { const auto & function = tool["function"]; auto tool_schema = json { {"type", "object"}, @@ -222,7 +261,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa tool_schema["required"].push_back("id"); } tool_call_schemas.emplace_back(tool_schema); - } + }); const auto tool_call = params.parallel_tool_calls ? json { @@ -276,7 +315,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```"); data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([&](const std::string & input) { + data.parser = std::make_unique([&](const std::string & input) { json data = json::parse(input); common_chat_msg result; result.role = "assistant"; @@ -303,13 +342,11 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa return data; } -static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { common_chat_data data; - auto builtin_tools = json {"wolfram_alpha", "brave_search"}; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); - for (const auto & tool : params.tools) { + foreach_normalized_tool(params.tools, [&](const json & tool) { const auto & function = tool["function"]; schemas.push_back({ {"type", "object"}, @@ -329,7 +366,7 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t }}, {"required", json::array({"name", "arguments", "id"})}, }); - } + }); auto schema = json { {"type", "array"}, {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, @@ -344,24 +381,14 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); } data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([](const std::string & input) -> common_chat_msg { return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); }); return data; } -static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { +static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { auto builtin_tools = json {"wolfram_alpha", "brave_search"}; - for (const auto & tool : params.tools) { - if (!tool.contains("type")) { - continue; - } - if (tool["type"] == "code_interpreter") { - builtin_tools.push_back("code_interpreter"); - break; - } - } - common_chat_data data; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -375,6 +402,7 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ } if (tool["type"] == "code_interpreter") { + builtin_tools.push_back("code_interpreter"); has_python = true; } else if (tool["type"] == "function" && tool.contains("function")) { const auto & function = tool["function"]; @@ -422,8 +450,10 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); - data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([params, uses_python_tag](const std::string & input) -> common_chat_msg { + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, { + {"builtin_tools", builtin_tools}, + }); + data.parser = std::make_unique([params, uses_python_tag](const std::string & input) -> common_chat_msg { if (uses_python_tag) { static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; @@ -448,11 +478,11 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ return data; } -static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { common_chat_data data; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); - for (const auto & tool : params.tools) { + foreach_normalized_tool(params.tools, [&](const json & tool) { const auto & function = tool["function"]; schemas.push_back({ {"type", "object"}, @@ -465,7 +495,7 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha }}, {"required", json::array({"name", "arguments", "id"})}, }); - } + }); auto schema = json { {"type", "array"}, {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, @@ -480,13 +510,13 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); } data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([](const std::string & input) -> common_chat_msg { return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); }); return data; } -static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_data data; @@ -530,7 +560,7 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); @@ -538,11 +568,12 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com return data; } -static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // TODO: handle tool {type: code_interpreter} as python common_chat_data data; + json tools = params.tools.is_null() ? params.tools : json::array(); data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; @@ -578,7 +609,7 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; @@ -602,12 +633,12 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c return data; } -static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { common_chat_data data; // (content)?({"name": "foo", "arguments": {"a": 1}})* data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - for (const auto & tool : params.tools) { + foreach_normalized_tool(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; @@ -620,8 +651,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t }}, {"required", json::array({"name", "arguments"})}, })); - } - + }); auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (params.tool_choice != "required") { @@ -630,7 +660,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique([&](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([&](const std::string & input) -> common_chat_msg { try { std::regex start_pattern(R"([\n\s]*)"); std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); @@ -677,24 +707,40 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t return data; } +static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) { + common_chat_data data; + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); + data.parser = std::make_unique(); + if (!params.json_schema.is_null()) { + if (!params.grammar.empty()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + data.grammar = json_schema_to_grammar(params.json_schema); + } else { + data.grammar = params.grammar.empty(); + } + return data; +} + common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) { if (params.tools.is_null()) { - common_chat_data data; - data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.handler = std::make_unique(); - return data; + return common_chat_init_without_tools(tmpl, params); } - const auto & src = tmpl.source(); + if (!params.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + + const auto & src = tmpl.source(); if (src.find("") != std::string::npos) { - return build_hermes_2_pro_tool_call_handler(tmpl, params); + return common_chat_init_hermes_2_pro_tool_call(tmpl, params); } if (src.find(">>>all") != std::string::npos) { - return build_functionary_v3_llama_3_tool_call_handler(tmpl, params); + return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params); } if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos; @@ -705,16 +751,16 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc // TODO: make this conditional on a very small model (e.g. 1B / 3B). auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2; - return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json); + return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json); } // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { // TODO: Command-R-Plus // } if (src.find("[TOOL_CALLS]") != std::string::npos) { - return build_mistral_nemo_tool_call_handler(tmpl, params); + return common_chat_init_mistral_nemo_tool_call(tmpl, params); } if (src.find(" functools[") != std::string::npos) { - return build_firefunction_v2_tool_call_handler(tmpl, params); + return common_chat_init_firefunction_v2_tool_call(tmpl, params); } - return build_generic_tool_call_handler(tmpl, params); + return common_chat_init_generic_tool_call(tmpl, params); } diff --git a/common/chat-handler.hpp b/common/chat-handler.hpp index 91304ab7e..bff810e58 100644 --- a/common/chat-handler.hpp +++ b/common/chat-handler.hpp @@ -22,6 +22,7 @@ struct common_chat_params { json json_schema; bool parallel_tool_calls; bool stream; + std::string grammar; }; class common_chat_parser { @@ -30,14 +31,15 @@ public: virtual std::optional parse_partial(const std::string & input) = 0; virtual common_chat_msg parse_final(const std::string & input) = 0; + virtual std::unique_ptr clone() const = 0; }; struct common_chat_data { - std::string prompt; + json prompt; std::string grammar; std::vector grammar_triggers; std::vector additional_stops; - std::unique_ptr handler; + std::unique_ptr parser; }; struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params); diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index cb42e9b49..e787601e6 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -1,4 +1,5 @@ #include "chat-handler.hpp" +#include "chat-template.hpp" #include "llama-grammar.h" #include "unicode.h" @@ -71,9 +72,7 @@ static std::string dump(const json & j) { return minja::Value(j).dump(-1, /* to_json= */ true); } -static void test_parse_tool_call(common_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { - std::cout << "# Testing: " << input << std::endl << std::flush; - auto result = parse_tool_calls(style, tools, input); +static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) { assert_equals(expected_content, result.content); auto tool_calls = json::array(); for (const auto & tc : result.tool_calls) { @@ -140,209 +139,216 @@ const auto code_interpreter_tool = json::parse(R"({ })"); const json tools = {special_function_tool, code_interpreter_tool}; -static void test_parsing() { - json request = { - {"tools", tools} - }; +// static void test_parsing() { +// json request = { +// {"tools", tools} +// }; - const auto fooBarCall = json { - {"type", "function"}, - {"function", { - {"name", "foo"}, - {"arguments", dump({ - {"bar", 1} - })}, - }} - }; +// const auto fooBarCall = json { +// {"type", "function"}, +// {"function", { +// {"name", "foo"}, +// {"arguments", dump({ +// {"bar", 1} +// })}, +// }} +// }; - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, - "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}", - "", - json::array({fooBarCall})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, - "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}", - "", - json::array({fooBarCall})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, +// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}", +// "", +// json::array({fooBarCall})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools, +// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}", +// "", +// json::array({fooBarCall})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools, - "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", - "", - json::array({fooBarCall})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools, +// "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", +// "", +// json::array({fooBarCall})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, - ">>>python\n{\"code\": \"print('Hello, world!')\"}", - "", - json {{ - {"type", "function"}, - {"function", { - {"name", "python"}, - {"arguments", dump({ - {"code", "print('Hello, world!')"} - })} - }} - }}); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, - ">>>special_function\n{\"arg1\": 1}\n ", - "", - json {{ - {"type", "function"}, - {"function", { - {"name", "special_function"}, - {"arguments", dump({ - {"arg1", 1} - })} - }} - }}); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, +// ">>>python\n{\"code\": \"print('Hello, world!')\"}", +// "", +// json {{ +// {"type", "function"}, +// {"function", { +// {"name", "python"}, +// {"arguments", dump({ +// {"code", "print('Hello, world!')"} +// })} +// }} +// }}); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools, +// ">>>special_function\n{\"arg1\": 1}\n ", +// "", +// json {{ +// {"type", "function"}, +// {"function", { +// {"name", "special_function"}, +// {"arguments", dump({ +// {"arg1", 1} +// })} +// }} +// }}); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, - "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", - "Hello, world!", - json { - { - {"type", "function"}, - {"function", { - {"name", "foo"}, - {"arguments", dump({ - {"arg1", 1} - })} - }} - }, - { - {"type", "function"}, - {"function", { - {"name", "bar"}, - {"arguments", dump({ - {"arg2", 2} - })} - }} - }, - }); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, - "{ } ", - " ", - json {{ - {"type", "function"}, - {"function", { - {"name", "test"}, - {"arguments", "{}"} - }} - }}); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, +// "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", +// "Hello, world!", +// json { +// { +// {"type", "function"}, +// {"function", { +// {"name", "foo"}, +// {"arguments", dump({ +// {"arg1", 1} +// })} +// }} +// }, +// { +// {"type", "function"}, +// {"function", { +// {"name", "bar"}, +// {"arguments", dump({ +// {"arg2", 2} +// })} +// }} +// }, +// }); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools, +// "{ } ", +// " ", +// json {{ +// {"type", "function"}, +// {"function", { +// {"name", "test"}, +// {"arguments", "{}"} +// }} +// }}); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "<|python_tag|>this could be anything", - "", - json {{ - {"type", "function"}, - {"function", { - {"name", "python"}, - {"arguments", "this could be anything"}, - }} - }}); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "I'm thinking<|python_tag|>", - "I'm thinking", - json {{ - {"type", "function"}, - {"function", { - {"name", "python"}, - {"arguments", ""}, - }} - }}); - auto special_function_call = json { - {"type", "function"}, - {"function", { - {"arguments", dump({{"arg1", 1}})}, - {"name", "special_function"}, - }}, - }; - auto special_function_call_with_id = json::parse(special_function_call.dump()); - special_function_call_with_id["id"] = "123456789"; +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "<|python_tag|>this could be anything", +// "", +// json {{ +// {"type", "function"}, +// {"function", { +// {"name", "python"}, +// {"arguments", "this could be anything"}, +// }} +// }}); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "I'm thinking<|python_tag|>", +// "I'm thinking", +// json {{ +// {"type", "function"}, +// {"function", { +// {"name", "python"}, +// {"arguments", ""}, +// }} +// }}); +// auto special_function_call = json { +// {"type", "function"}, +// {"function", { +// {"arguments", dump({{"arg1", 1}})}, +// {"name", "special_function"}, +// }}, +// }; +// auto special_function_call_with_id = json::parse(special_function_call.dump()); +// special_function_call_with_id["id"] = "123456789"; - auto no_function_call = json::array(); +// auto no_function_call = json::array(); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}", - "", - json::array({{ - {"type", "function"}, - {"function", { - {"arguments", dump({{"code", "print('Hey')"}})}, - {"name", "python"}, - }} - }})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json::array({special_function_call})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json::array({special_function_call})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json::array({special_function_call})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json::array({special_function_call})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json::array({special_function_call})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}", +// "", +// json::array({{ +// {"type", "function"}, +// {"function", { +// {"arguments", dump({{"code", "print('Hey')"}})}, +// {"name", "python"}, +// }} +// }})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", +// "", +// json::array({special_function_call})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", +// "", +// json::array({special_function_call})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", +// "", +// json::array({special_function_call})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", +// "", +// json::array({special_function_call})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", +// "", +// json::array({special_function_call})); - // No match: function unknown - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - no_function_call); - // No match: bad indentation - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - no_function_call); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, - "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - no_function_call); +// // No match: function unknown +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// no_function_call); +// // No match: bad indentation +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// no_function_call); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, +// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", +// no_function_call); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools, - "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]", - "Bleh", - json::array({special_function_call_with_id})); +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools, +// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]", +// "Bleh", +// json::array({special_function_call_with_id})); - test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools, - "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]", - "Bleh", - json::array({special_function_call})); -} +// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools, +// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]", +// "Bleh", +// json::array({special_function_call})); +// } -static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) { - const common_chat_template tmpl(read_file(template_file), "", ""); - auto tool_call_style = common_tool_call_style_detect(tmpl); - std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; - assert_equals(expected, tool_call_style); -} +// static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) { +// const common_chat_template tmpl(read_file(template_file), "", ""); +// auto tool_call_style = common_tool_call_style_detect(tmpl); +// std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; +// assert_equals(expected, tool_call_style); +// } -static void test_tool_call_style_detection() { - test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1); - test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3); - test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2); - test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1); - test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2); - test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); - test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); - test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); - test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS); - test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO); - test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC); -} +// static void test_tool_call_style_detection() { +// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1); +// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3); +// test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2); +// test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1); +// test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2); +// test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); +// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); +// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO); +// test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS); +// test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO); +// test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC); +// } static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { fprintf(stderr, "Template source: %s\n", tmpl.source().c_str()); fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str()); - auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); - auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); + + common_chat_params params; + params.parallel_tool_calls = true; + params.messages = json::array(); + params.messages.push_back(user_message); + params.tools = tools; + std::string prefix = common_chat_init(tmpl, params).prompt; + params.messages.push_back(delta_message); + std::string full = common_chat_init(tmpl, params).prompt; // Check full starts with prefix if (full.find(prefix) != 0) { @@ -364,7 +370,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c } static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { - auto tool_call_style = common_tool_call_style_detect(tmpl); + // auto tool_call_style = common_tool_call_style_detect(tmpl); auto & tool_calls = tool_calling_message.at("tool_calls"); // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, @@ -374,8 +380,13 @@ static void test_template(const common_chat_template & tmpl, const std::vector().c_str()); + auto grammar = build_grammar(chat_data.grammar); if (!grammar) { throw std::runtime_error("Failed to build grammar"); } @@ -383,7 +394,9 @@ static void test_template(const common_chat_template & tmpl, const std::vectorparse_final(full_delta); + assert_msg_equals(msg, "", tool_calls); auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { {"role", "assistant"}, @@ -391,7 +404,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector", ""); - // test_template(tmpl, { "" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); - // } + { + const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + test_template(tmpl, { "" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); + } // { // const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); // assert_equals(tmpl.requires_object_arguments_, true); @@ -457,14 +470,14 @@ static void test_grammars() { // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); // } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); - // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - // } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); - // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - // } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); + test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); + test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + } { const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);