From 39729457980a0a5ac9bf8e3a38b356d1723c3197 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 01:54:08 +0000 Subject: [PATCH] common_tool_call rename --- common/tool-call.cpp | 100 ++++++++++++++++++------------------- common/tool-call.h | 20 ++++---- examples/server/server.cpp | 14 +++--- examples/server/utils.hpp | 6 +-- tests/test-tool-call.cpp | 50 +++++++++---------- 5 files changed, 95 insertions(+), 95 deletions(-) diff --git a/common/tool-call.cpp b/common/tool-call.cpp index 0c2e802bd..a3ea52290 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -47,34 +47,34 @@ static json normalize_tools(const json & tools) { return results; } -std::string llama_tool_call_style_name(llama_tool_call_style style) { +std::string common_tool_call_style_name(common_tool_call_style style) { switch (style) { - case llama_tool_call_style::None: + case common_tool_call_style::None: return "None"; - case llama_tool_call_style::Generic: + case common_tool_call_style::Generic: return "Generic"; - case llama_tool_call_style::Llama31: + case common_tool_call_style::Llama31: return "Llama-3.1"; - case llama_tool_call_style::Llama32: + case common_tool_call_style::Llama32: return "Llama-3.2"; - case llama_tool_call_style::FunctionaryV3Llama3: + case common_tool_call_style::FunctionaryV3Llama3: return "FunctionaryV3Llama3"; - case llama_tool_call_style::FunctionaryV3Llama31: + case common_tool_call_style::FunctionaryV3Llama31: return "FunctionaryV3Llama3.1"; - case llama_tool_call_style::Hermes2Pro: + case common_tool_call_style::Hermes2Pro: return "Hermes2Pro"; - case llama_tool_call_style::CommandRPlus: + case common_tool_call_style::CommandRPlus: return "CommandRPlus"; - case llama_tool_call_style::MistralNemo: + case common_tool_call_style::MistralNemo: return "MistralNemo"; - case llama_tool_call_style::FirefunctionV2: + case common_tool_call_style::FirefunctionV2: return "FirefunctionV2"; default: return "Unknown"; } } -llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) { +common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) { const auto & src = chat_template.source(); if (src.find("<tool_call>") != std::string::npos) { @@ -150,10 +150,10 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content.
*/ -static llama_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { +static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { std::smatch match; - llama_tool_calls result; + common_tool_calls result; auto end = input.end(); auto it = input.begin(); @@ -202,7 +202,7 @@ static llama_tool_calls parse_json_tool_calls(const json & tools, const std::str return result; } -static llama_tool_calls parse_hermes_tool_calls(const std::string& input) { +static common_tool_calls parse_hermes_tool_calls(const std::string& input) { try { std::regex start_pattern(R"([\n\s]*<tool_call>)"); std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)"); @@ -215,7 +215,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) { return {input, {}}; } - llama_tool_calls result; + common_tool_calls result; result.content = rit->prefix(); auto it = rit->suffix().first; @@ -246,7 +246,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) { } } -static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { +static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { if (allow_python_tag) { static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; @@ -268,7 +268,7 @@ static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std:: return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } -static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { +static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; @@ -289,15 +289,15 @@ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & t return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); } -static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { +static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } -static llama_tool_calls parse_generic_tool_calls(const std::string& input) { +static common_tool_calls parse_generic_tool_calls(const std::string& input) { json data = json::parse(input); - llama_tool_calls result; + common_tool_calls result; if (data.contains("tool_calls")) { for (const auto & tool_call : data["tool_calls"]) { result.tool_calls.push_back({ @@ -319,11 +319,11 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) { return result; } -static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { +static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { auto content_end = input.find(prefix); size_t tc_start = std::string::npos; - llama_tool_calls result; + common_tool_calls result; const auto process_tool_calls = [&](const json & tool_calls) { for (const auto & tool_call : tool_calls) { const auto & arguments = tool_call["arguments"]; @@ -345,34 +345,34 @@ static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& i return result; } -static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) { +static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) { return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } -static llama_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) { +static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) { return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); } -llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) { - fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str()); +common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { + fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str()); switch (style) { - case llama_tool_call_style::None: + case common_tool_call_style::None: return {input, {}}; - case llama_tool_call_style::Generic: + case common_tool_call_style::Generic: return parse_generic_tool_calls(input); - case llama_tool_call_style::Llama31: + case common_tool_call_style::Llama31: return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ true); - case llama_tool_call_style::Llama32: + case common_tool_call_style::Llama32: return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ false); - case llama_tool_call_style::FunctionaryV3Llama3: + case common_tool_call_style::FunctionaryV3Llama3: return parse_functionary_v3_tool_calls(tools, input); -
case llama_tool_call_style::FunctionaryV3Llama31: + case common_tool_call_style::FunctionaryV3Llama31: return parse_functionary_v3_llama_3_1_tool_calls(tools, input); - case llama_tool_call_style::Hermes2Pro: + case common_tool_call_style::Hermes2Pro: return parse_hermes_tool_calls(input); - case llama_tool_call_style::MistralNemo: + case common_tool_call_style::MistralNemo: return parse_mistral_nemo_tool_calls(input); - case llama_tool_call_style::FirefunctionV2: + case common_tool_call_style::FirefunctionV2: return parse_firefunction_v2_tool_calls(input); default: throw std::runtime_error("Unsupported tool call style"); @@ -397,8 +397,8 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages return messages_with_system; } -llama_tool_call_handler llama_tool_call_handler_init( - llama_tool_call_style style, +common_tool_call_handler common_tool_call_handler_init( + common_tool_call_style style, const common_chat_template & tmpl, bool allow_content, const nlohmann::ordered_json & parallel_tool_calls, @@ -406,14 +406,14 @@ llama_tool_call_handler llama_tool_call_handler_init( const nlohmann::ordered_json & tools, const nlohmann::ordered_json & json_schema) { - llama_tool_call_handler handler; + common_tool_call_handler handler; auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>(); switch (style) { - case llama_tool_call_style::None: + case common_tool_call_style::None: handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); break; - case llama_tool_call_style::Generic: { + case common_tool_call_style::Generic: { auto actual_tools = normalize_tools(tools); auto tool_call_schemas = json::array(); for (const auto & tool : actual_tools) { @@ -493,7 +493,7 @@ llama_tool_call_handler llama_tool_call_handler_init( handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); break; } - case llama_tool_call_style::MistralNemo: { + case common_tool_call_style::MistralNemo: { auto actual_tools = normalize_tools(tools); handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { auto schemas = json::array(); @@ -534,7 +534,7 @@ llama_tool_call_handler llama_tool_call_handler_init( handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); break; } - case llama_tool_call_style::FirefunctionV2: { + case common_tool_call_style::FirefunctionV2: { auto actual_tools = normalize_tools(tools); handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { auto schemas = json::array(); @@ -568,8 +568,8 @@ llama_tool_call_handler llama_tool_call_handler_init( handler.prompt = tmpl.apply(messages, actual_tools.empty() ?
json() : actual_tools, /* add_generation_prompt= */ true); break; } - case llama_tool_call_style::Llama31: - case llama_tool_call_style::Llama32: { + case common_tool_call_style::Llama31: + case common_tool_call_style::Llama32: { auto builtin_tools = json {"wolfram_alpha", "brave_search"}; for (const auto & tool : tools) { if (!tool.contains("type")) { @@ -582,13 +582,13 @@ llama_tool_call_handler llama_tool_call_handler_init( } auto actual_tools = normalize_tools(tools); - auto uses_python_tag = style == llama_tool_call_style::Llama31; + auto uses_python_tag = style == common_tool_call_style::Llama31; // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon // as it seems to be outputting some JSON. // TODO: make this conditional on a very small model (e.g. 1B / 3B). - auto eagerly_match_any_json = style == llama_tool_call_style::Llama32; + auto eagerly_match_any_json = style == common_tool_call_style::Llama32; handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector<std::string> tool_rules; @@ -639,7 +639,7 @@ llama_tool_call_handler llama_tool_call_handler_init( }); break; } - case llama_tool_call_style::FunctionaryV3Llama3: { + case common_tool_call_style::FunctionaryV3Llama3: { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar auto actual_tools = normalize_tools(tools); @@ -670,7 +670,7 @@ llama_tool_call_handler llama_tool_call_handler_init( // handler.parser = parse_functionary_3_2_tool_calls; break; } - case llama_tool_call_style::FunctionaryV3Llama31: { + case common_tool_call_style::FunctionaryV3Llama31: { // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // TODO: handle tool {type: code_interpreter} as python @@ -700,7 +700,7 @@ llama_tool_call_handler llama_tool_call_handler_init( // handler.parser = parse_functionary_3_2_tool_calls; break; } - case llama_tool_call_style::Hermes2Pro: { + case common_tool_call_style::Hermes2Pro: { // NousResearchHermesPro_2 // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* auto actual_tools = normalize_tools(tools); diff --git a/common/tool-call.h b/common/tool-call.h index b83faa772..022c26f4c 100644 --- a/common/tool-call.h +++ b/common/tool-call.h @@ -7,7 +7,7 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -enum llama_tool_call_style { +enum common_tool_call_style { UnknownToolCallStyle, None, Generic, @@ -21,32 +21,32 @@ FirefunctionV2, }; -struct llama_tool_call { +struct common_tool_call { std::string name; std::string arguments; std::string id; }; -struct llama_tool_calls { +struct common_tool_calls { std::string content; - std::vector<llama_tool_call> tool_calls; + std::vector<common_tool_call> tool_calls; }; -struct llama_tool_call_handler { +struct common_tool_call_handler { std::string prompt; std::string grammar; std::vector<std::string> grammar_triggers; std::vector<std::string> additional_stops; }; -std::string llama_tool_call_style_name(llama_tool_call_style style); +std::string common_tool_call_style_name(common_tool_call_style style); -llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template); +common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template); -llama_tool_calls parse_tool_calls(llama_tool_call_style style, const
nlohmann::ordered_json & tools, const std::string& input); +common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); -llama_tool_call_handler llama_tool_call_handler_init( - llama_tool_call_style style, +common_tool_call_handler common_tool_call_handler_init( + common_tool_call_style style, const common_chat_template & tmpl, bool allow_content, const nlohmann::ordered_json & parallel_tool_calls, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 97430941e..3f2ab6fb3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -118,7 +118,7 @@ struct slot_params { std::string oaicompat_model; std::string oaicompat_cmpl_id; json oaicompat_tools; - llama_tool_call_style oaicompat_tool_call_style = llama_tool_call_style::None; + common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None; json to_json() const { std::vector<std::string> samplers; @@ -589,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result { std::string oaicompat_model; std::string oaicompat_cmpl_id; json oaicompat_tools; - llama_tool_call_style oaicompat_tool_call_style = llama_tool_call_style::None; + common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None; virtual int get_index() override { return index; @@ -687,10 +687,10 @@ struct server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - llama_tool_calls parsed_tool_calls; + common_tool_calls parsed_tool_calls; json tool_calls; json message_content; - if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) { + if (oaicompat_tool_call_style != common_tool_call_style::None && !oaicompat_tools.is_null()) { parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); if (!parsed_tool_calls.tool_calls.empty()) { finish_reason = "tool_calls"; @@ -3772,7 +3772,7 @@ int main(int argc, char ** argv) { std::function<bool()> is_connection_closed, httplib::Response & res, oaicompat_type oaicompat, - llama_tool_call_style tool_call_style = llama_tool_call_style::None) { + common_tool_call_style tool_call_style = common_tool_call_style::None) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); if (ctx_server.params_base.embedding) { @@ -3979,8 +3979,8 @@ int main(int argc, char ** argv) { auto body = json::parse(req.body); const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ?
*ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; - auto tool_call_style = llama_tool_call_style_detect(chat_template); - LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str()); + auto tool_call_style = common_tool_call_style_detect(chat_template); + LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str()); json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 7641d3410..75e1a876a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -581,7 +581,7 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ const common_chat_template & tmpl, - llama_tool_call_style tool_call_style, + common_tool_call_style tool_call_style, bool use_jinja) { json llama_params; @@ -595,7 +595,7 @@ static json oaicompat_completion_params_parse( throw std::runtime_error("Cannot use tools with stream"); } if (use_jinja) { - if (tool_call_style == llama_tool_call_style::UnknownToolCallStyle) { + if (tool_call_style == common_tool_call_style::UnknownToolCallStyle) { throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template."); } } else { @@ -634,7 +634,7 @@ static json oaicompat_completion_params_parse( auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json(); llama_params["parallel_tool_calls"] = parallel_tool_calls; - auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]); + auto handler = common_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]); llama_params["prompt"] = handler.prompt; for (const auto & stop : handler.additional_stops) { diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index b25d6c91e..90b7e1296 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -70,7 +70,7 @@ static std::string dump(const json & j) { return minja::Value(j).dump(-1, /* to_json= */ true); } -static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { +static void test_parse_tool_call(common_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { std::cout << "# Testing: " << input << std::endl << std::flush; auto result = parse_tool_calls(style, tools, input); assert_equals(expected_content, result.content); @@ -146,21 +146,21 @@ static void test_parsing() { }} }; - test_parse_tool_call(llama_tool_call_style::Generic, tools, + test_parse_tool_call(common_tool_call_style::Generic, tools, "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}", "", json::array({fooBarCall})); - test_parse_tool_call(llama_tool_call_style::Generic, tools, + test_parse_tool_call(common_tool_call_style::Generic, tools, "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}", "", json::array({fooBarCall})); - test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools, + test_parse_tool_call(common_tool_call_style::Hermes2Pro, tools, 
"{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", "", json::array({fooBarCall})); - test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, + test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama3, tools, ">>>python\n{\"code\": \"print('Hello, world!')\"}", "", json {{ @@ -172,7 +172,7 @@ static void test_parsing() { })} }} }}); - test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, + test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama3, tools, ">>>special_function\n{\"arg1\": 1}\n ", "", json {{ @@ -185,7 +185,7 @@ static void test_parsing() { }} }}); - test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools, + test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama31, tools, "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", "Hello, world!", json { @@ -208,7 +208,7 @@ static void test_parsing() { }} }, }); - test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools, + test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama31, tools, "{ } ", " ", json {{ @@ -219,7 +219,7 @@ static void test_parsing() { }} }}); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "<|python_tag|>this could be anything", "", json {{ @@ -231,7 +231,7 @@ static void test_parsing() { })} }} }}); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "I'm thinking<|python_tag|>", "I'm thinking", json {{ @@ -253,7 +253,7 @@ static void test_parsing() { auto no_function_call = json::array(); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}", "", json::array({{ @@ -263,56 +263,56 @@ static void test_parsing() { {"name", "python"}, }} }})); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json::array({special_function_call})); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json::array({special_function_call})); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json::array({special_function_call})); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json::array({special_function_call})); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json::array({special_function_call})); // No match: function unknown - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", no_function_call); // No match: bad indentation - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + 
test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", no_function_call); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, + test_parse_tool_call(common_tool_call_style::Llama31, tools, "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", no_function_call); - test_parse_tool_call(llama_tool_call_style::MistralNemo, tools, + test_parse_tool_call(common_tool_call_style::MistralNemo, tools, "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]", "Bleh", json::array({special_function_call_with_id})); - test_parse_tool_call(llama_tool_call_style::FirefunctionV2, tools, + test_parse_tool_call(common_tool_call_style::FirefunctionV2, tools, "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]", "Bleh", json::array({special_function_call})); } -static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) { +static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) { const common_chat_template tmpl(read_file(template_file), "", ""); - auto tool_call_style = llama_tool_call_style_detect(tmpl); + auto tool_call_style = common_tool_call_style_detect(tmpl); std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; assert_equals(expected, tool_call_style); } @@ -357,7 +357,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { std::cout << "# Testing template: " << template_file << std::endl << std::flush; const common_chat_template tmpl(read_file(template_file), bos_token, eos_token); - auto tool_call_style = llama_tool_call_style_detect(tmpl); + auto tool_call_style = common_tool_call_style_detect(tmpl); auto & tool_calls = tool_calling_message.at("tool_calls"); // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, @@ -367,7 +367,7 @@ static void test_template(const std::string & template_file, const char * bos_to {"content", "Hello, world!"} }; - auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); + auto handler = common_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); auto grammar = build_grammar(handler.grammar); if (!grammar) { throw std::runtime_error("Failed to build grammar");
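
For reference, a minimal sketch of how the renamed entry points fit together end to end, assuming only the post-rename declarations in common/tool-call.h above: detect the style from the chat template, build the handler (prompt, grammar, triggers), then parse the model's completion text back into structured tool calls. The wrapper name sketch_tool_call_roundtrip and the null-json arguments passed for parallel_tool_calls and json_schema are illustrative assumptions, not part of this commit:

    #include "tool-call.h"
    #include <string>

    static common_tool_calls sketch_tool_call_roundtrip(
            const common_chat_template   & tmpl,
            const nlohmann::ordered_json & messages,
            const nlohmann::ordered_json & tools,
            const std::string            & model_output) {
        // Pick the parsing/grammar style from the template source.
        auto style = common_tool_call_style_detect(tmpl);
        // Build the prompt and grammar constraints for this style.
        auto handler = common_tool_call_handler_init(
            style, tmpl,
            /* allow_content=       */ true,
            /* parallel_tool_calls= */ nlohmann::ordered_json(), // null: use the template's default
            messages, tools,
            /* json_schema=         */ nlohmann::ordered_json());
        // handler.prompt, handler.grammar, handler.grammar_triggers and
        // handler.additional_stops would drive generation; this sketch only
        // parses a (hypothetical) completion back into tool calls.
        (void) handler;
        return parse_tool_calls(style, tools, model_output);
    }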