From 3e12b9b38ecc19f2f16081e7d7576af696cac4ad Mon Sep 17 00:00:00 2001
From: ochafik
Date: Wed, 23 Oct 2024 02:30:31 +0100
Subject: [PATCH] `tool-calls`: basic Nemo support, default parallel to true
 if template mentions tool_call_id

---
 common/chat-template.hpp          |   3 +
 common/tool-call.cpp              | 184 ++++++++++++++++++++++++------
 common/tool-call.h                |   5 +-
 examples/agent/README.md          |  15 +--
 tests/chat/contexts/tool_use.json |   9 +-
 tests/test-tool-call.cpp          |  88 +++++++++-----
 6 files changed, 227 insertions(+), 77 deletions(-)

diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 47ec0d402..7e3932174 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -26,6 +26,7 @@ class chat_template {
     // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
     bool _requires_object_arguments = false;
     bool _supports_system_role = true;
+    bool _supports_parallel_tool_calls = false;
     std::string _source;
     std::string _bos_token;
     std::string _eos_token;
@@ -40,6 +41,7 @@ class chat_template {
             source.find("tool_call.arguments | items") != std::string::npos
                 || source.find("tool_call.arguments | tojson") != std::string::npos;
         _supports_system_role = source.find("System role not supported") == std::string::npos;
+        _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;

         _template_root = minja::Parser::parse(_source, {
             /* .trim_blocks = */ true,
@@ -50,6 +52,7 @@ class chat_template {

     const std::string & source() const { return _source; }
     bool supports_tools() const { return _supports_tools; }
+    bool supports_parallel_tool_calls() const { return _supports_parallel_tool_calls; }

     std::string apply(
         const nlohmann::ordered_json & messages,
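A note on the heuristic above before moving on: OpenAI-style clients send `parallel_tool_calls` as an *optional* boolean, and the templates that reference `tool_call_id` (Mistral Nemo's among them) are precisely the ones that can route several tool results back to their originating calls, so the mention doubles as a cheap signal that parallel calls are safe to enable by default. A minimal sketch of how the two are meant to combine (the helper name `resolve_parallel` is illustrative, not part of the patch; the actual resolution happens inline in `llama_tool_call_handler_init` below):

```cpp
#include <nlohmann/json.hpp>
#include "chat-template.hpp"

// A JSON null means "unset": defer to the template heuristic; an explicit
// boolean from the client always wins.
static bool resolve_parallel(const minja::chat_template & tmpl,
                             const nlohmann::ordered_json & parallel_tool_calls) {
    return parallel_tool_calls.is_null()
        ? tmpl.supports_parallel_tool_calls() // template source mentions "tool_call_id"
        : parallel_tool_calls.get<bool>();
}
```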
(data.contains("tool_call")) { result.tool_calls.push_back({ data["tool_call"]["name"], data["tool_call"]["arguments"].dump(), + /* id= */ "", }); } else if (data.contains("response")) { const auto & response = data["response"]; @@ -255,8 +264,38 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) { return result; } +static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) { + auto content_end = input.find("[TOOL_CALLS]"); + size_t tc_start = std::string::npos; + if (content_end != std::string::npos) { + tc_start = content_end + 12; + } else { + // Somehow not getting [TOOL_CALLS] in the output. Oh well, just do without it. + content_end = input.find("[{\""); + if (content_end == std::string::npos || content_end > 0) { + return {input, {}}; + } + tc_start = content_end; + } + llama_tool_calls result; + result.content = input.substr(0, content_end); + auto tool_calls = json::parse(input.substr(tc_start)); + for (const auto & tool_call : tool_calls) { + const auto & arguments = tool_call["arguments"]; + result.tool_calls.push_back({ + tool_call["name"], + arguments.is_string() ? arguments.get() : arguments.dump(), + tool_call["id"], + }); + } + return result; +} + llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) { + fprintf(stderr, "# parse_tool_calls:\n\n%s\n\n", input.c_str()); switch (style) { + case llama_tool_call_style::None: + return {input, {}}; case llama_tool_call_style::Generic: return parse_generic_tool_calls(input); case llama_tool_call_style::Llama31: @@ -269,23 +308,43 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool return parse_functionary_v3_llama_3_1_tool_calls(tools, input); case llama_tool_call_style::Hermes2Pro: return parse_hermes_tool_calls(input); + case llama_tool_call_style::MistralNemo: + return parse_mistral_nemo_tool_calls(input); default: throw std::runtime_error("Unsupported tool call style"); } } +static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { + json messages_with_system = messages; + + if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + messages_with_system.at(0).at("content") += ("\n" + system_prompt); + } else { + messages_with_system.insert(messages_with_system.begin(), json { + {"role", "system"}, + {"content", system_prompt}, + }); + } + return messages_with_system; +} + llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_style style, const minja::chat_template & tmpl, bool allow_content, - bool parallel_tool_calls, + const nlohmann::ordered_json & parallel_tool_calls, const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, const nlohmann::ordered_json & json_schema) { llama_tool_call_handler handler; + auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get(); switch (style) { + case llama_tool_call_style::None: + handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); + break; case llama_tool_call_style::Generic: { auto tool_call_schemas = json::array(); for (const auto & tool : tools) { @@ -307,43 +366,98 @@ llama_tool_call_handler llama_tool_call_handler_init( {"required", json::array({"name", "arguments"})}, }); } - const auto tool_call = json {{"anyOf", tool_call_schemas}}; - const auto schema = json { - {"anyOf", json::array({ - parallel_tool_calls - ? 
@@ -307,43 +366,98 @@ llama_tool_call_handler llama_tool_call_handler_init(
                     {"required", json::array({"name", "arguments"})},
                 });
             }
-            const auto tool_call = json {{"anyOf", tool_call_schemas}};
-            const auto schema = json {
-                {"anyOf", json::array({
-                    parallel_tool_calls
-                        ? json {
-                            {"type", "object"},
-                            {"properties", {
-                                {"tool_calls", {
-                                    {"type", "array"},
-                                    {"items", tool_call}
-                                }},
-                            }},
-                            {"required", json::array({"tool_calls"})},
-                        }
-                        : json {
-                            {"type", "object"},
-                            {"properties", {
-                                {"tool_call", tool_call},
-                            }},
-                            {"required", json::array({"tool_call"})},
-                        },
-                    {
-                        {"type", "object"},
-                        {"properties", {
-                            {"response", json_schema.is_null()
-                                ? json {{"type", "string"}}
-                                : json_schema
-                            },
-                        }},
-                    },
-                })}
-            };
+            const auto tool_call =
+                parallel
+                    ? json {
+                        {"type", "object"},
+                        {"properties", {
+                            {"tool_calls", {
+                                {"type", "array"},
+                                {"items", json {{"anyOf", tool_call_schemas}}}
+                            }},
+                        }},
+                        {"required", json::array({"tool_calls"})},
+                    }
+                    : json {
+                        {"type", "object"},
+                        {"properties", {
+                            {"tool_call", json {{"anyOf", tool_call_schemas}}},
+                        }},
+                        {"required", json::array({"tool_call"})},
+                    };
+            const auto schema =
+                allow_content
+                    ? json {
+                        {"anyOf", json::array({
+                            tool_call,
+                            {
+                                {"type", "object"},
+                                {"properties", {
+                                    {"response", json_schema.is_null()
+                                        ? json {{"type", "string"}}
+                                        : json_schema
+                                    },
+                                }},
+                            },
+                        })}
+                    }
+                    : tool_call;
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 builder.add_schema("", schema);
             });
             // TODO: add schema to system prompt.
-            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
+            auto tweaked_messages = add_system(
+                messages,
+                "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
+            handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true);
             break;
         }
+        case llama_tool_call_style::MistralNemo: {
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                auto schemas = json::array();
+                for (const auto & tool : tools) {
+                    if (tool["type"] != "function") {
+                        continue;
+                    }
+                    const auto & function = tool["function"];
+                    std::string name = function["name"];
+                    auto parameters = function["parameters"];
+                    auto schema = json {
+                        {"type", "object"},
+                        {"properties", {
+                            // Important note: the model is probably trained to take a JSON stringified arguments value.
+                            // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
+                            {"arguments", parameters},
+                            {"name", {
+                                {"type", "string"},
+                                {"const", name},
+                            }},
+                            {"id", {
+                                {"type", "string"},
+                                // Nemo's template expects a 9-character alphanumeric ID.
+                                {"pattern", "^[a-zA-Z0-9]{9}$"},
+                            }},
+                        }},
+                        {"required", json::array({"arguments", "id", "name"})},
+                    };
+                    schemas.push_back(schema);
+                }
+                auto schema = json {
+                    {"type", "array"},
+                    {"items", json {{"anyOf", schemas}}},
+                    {"minItems", 1},
+                };
+                if (!parallel) {
+                    schema["maxItems"] = 1;
+                }
+                builder.add_schema("", schema);
+            });
+            if (allow_content) {
+                handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
+                handler.grammar_trigger_words.push_back("[{\"");
+            }
+            auto tweaked_messages = add_system(messages, "Prefix any tool calls with [TOOL_CALLS]");
+            handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true);
+            break;
+        }
         case llama_tool_call_style::Llama31:
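For a single hypothetical `get_weather(city: string)` tool, the schema handed to `build_grammar` above comes out roughly as follows (the tool and its parameters are invented for illustration; a `"maxItems": 1` entry is added when parallel calls are disabled):

```json
{
  "type": "array",
  "items": {
    "anyOf": [
      {
        "type": "object",
        "properties": {
          "arguments": {
            "type": "object",
            "properties": { "city": { "type": "string" } }
          },
          "name": { "type": "string", "const": "get_weather" },
          "id": { "type": "string", "pattern": "^[a-zA-Z0-9]{9}$" }
        },
        "required": ["arguments", "id", "name"]
      }
    ]
  },
  "minItems": 1
}
```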
+ {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"arguments", "id", "name"})}, + }; + schemas.push_back(schema); + } + auto schema = json { + {"type", "array"}, + {"items", json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!parallel) { + schema["maxItems"] = 1; + } + builder.add_schema("", schema); + }); + if (allow_content) { + handler.grammar_trigger_words.push_back("[TOOL_CALLS]"); + handler.grammar_trigger_words.push_back("[{\""); + } + auto tweaked_messages = add_system(messages, "Prefix any tool calls with [TOOL_CALLS]"); + handler.prompt = tmpl.apply(tweaked_messages, tools, /* add_generation_prompt= */ true); break; } case llama_tool_call_style::Llama31: @@ -427,7 +541,7 @@ llama_tool_call_handler llama_tool_call_handler_init( } } auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; - if (parallel_tool_calls) { + if (parallel) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); } else { @@ -459,7 +573,7 @@ llama_tool_call_handler llama_tool_call_handler_init( } } auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; - builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back("\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; - builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + builder.add_rule("root", parallel ? 
"(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back(""); } diff --git a/common/tool-call.h b/common/tool-call.h index 94f5a04ae..6d1265460 100644 --- a/common/tool-call.h +++ b/common/tool-call.h @@ -9,6 +9,7 @@ enum llama_tool_call_style { UnknownToolCallStyle, + None, Generic, Llama31, Llama32, @@ -16,11 +17,13 @@ enum llama_tool_call_style { FunctionaryV3Llama31, Hermes2Pro, CommandRPlus, + MistralNemo, }; struct llama_tool_call { std::string name; std::string arguments; + std::string id; }; struct llama_tool_calls { @@ -45,7 +48,7 @@ llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_style style, const minja::chat_template & tmpl, bool allow_content, - bool parallel_tool_calls, + const nlohmann::ordered_json & parallel_tool_calls, const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, const nlohmann::ordered_json & json_schema = {}); diff --git a/examples/agent/README.md b/examples/agent/README.md index aa04f0a96..2edcc8473 100644 --- a/examples/agent/README.md +++ b/examples/agent/README.md @@ -7,6 +7,11 @@ ```bash make -j LLAMA_CURL=1 llama-server + # Mistral NeMo + ./llama-server --jinja -fa --verbose \ + -hfr bartowski/Mistral-Nemo-Instruct-2407-GGUF -hff Mistral-Nemo-Instruct-2407-Q8_0.gguf \ + --chat-template "$( python scripts/get_hf_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 )" + # Nous Hermes 2 Pro Llama 3 8B ./llama-server --jinja -fa --verbose \ -hfr NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF -hff Hermes-2-Pro-Llama-3-8B-Q8_0.gguf \ @@ -27,7 +32,7 @@ # Llama 3.2 3B (poor adherence) ./llama-server --jinja -fa --verbose \ - -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K_L.gguf \ + -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \ --chat-template "$( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )" # Llama 3.2 1B (very poor adherence) @@ -39,12 +44,8 @@ - Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running): ```bash - export BRAVE_SEARCH_API_KEY=... # https://api.search.brave.com/ - # Shorthand: ./examples/agent/serve_tools_inside_docker.sh - docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \ - --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \ - --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \ - uv run serve_tools.py --port 8088 + export BRAVE_SEARCH_API_KEY=... 
diff --git a/examples/agent/README.md b/examples/agent/README.md
index aa04f0a96..2edcc8473 100644
--- a/examples/agent/README.md
+++ b/examples/agent/README.md
@@ -7,6 +7,11 @@
   ```bash
   make -j LLAMA_CURL=1 llama-server

+  # Mistral NeMo
+  ./llama-server --jinja -fa --verbose \
+    -hfr bartowski/Mistral-Nemo-Instruct-2407-GGUF -hff Mistral-Nemo-Instruct-2407-Q8_0.gguf \
+    --chat-template "$( python scripts/get_hf_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 )"
+
   # Nous Hermes 2 Pro Llama 3 8B
   ./llama-server --jinja -fa --verbose \
     -hfr NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF -hff Hermes-2-Pro-Llama-3-8B-Q8_0.gguf \
@@ -27,7 +32,7 @@

   # Llama 3.2 3B (poor adherence)
   ./llama-server --jinja -fa --verbose \
-    -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K_L.gguf \
+    -hfr lmstudio-community/Llama-3.2-3B-Instruct-GGUF -hff Llama-3.2-3B-Instruct-Q6_K.gguf \
     --chat-template "$( python scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct )"

   # Llama 3.2 1B (very poor adherence)
@@ -39,12 +44,8 @@
 - Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):

   ```bash
-  export BRAVE_SEARCH_API_KEY=... # https://api.search.brave.com/
-  # Shorthand: ./examples/agent/serve_tools_inside_docker.sh
-  docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \
-    --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
-    --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
-    uv run serve_tools.py --port 8088
+  export BRAVE_SEARCH_API_KEY=... # Get one at https://api.search.brave.com/
+  ./examples/agent/serve_tools_inside_docker.sh
   ```

 > [!WARNING]
diff --git a/tests/chat/contexts/tool_use.json b/tests/chat/contexts/tool_use.json
index 6acaef313..2797ac5c7 100644
--- a/tests/chat/contexts/tool_use.json
+++ b/tests/chat/contexts/tool_use.json
@@ -9,7 +9,7 @@
       "content": "",
       "tool_calls": [
         {
-          "id": "call_1",
+          "id": "call_1___",
           "type": "function",
           "function": {
             "arguments": "{\"code\": \"print('Hello, World!')\"}",
@@ -20,6 +20,7 @@
     },
     {
       "role": "tool",
+      "tool_call_id": "call_1___",
       "name": "ipython",
       "content": "{\"stdout\": \"Hello, World!\"}"
     },
@@ -36,7 +37,7 @@
       "content": "",
       "tool_calls": [
         {
-          "id": "call_2",
+          "id": "call_2___",
           "type": "function",
           "function": {
             "arguments": "{\"condition\":true}",
@@ -47,6 +48,7 @@
     },
     {
       "role": "tool",
+      "tool_call_id": "call_2___",
       "name": "test",
       "content": "true"
     },
@@ -63,7 +65,7 @@
       "content": "",
       "tool_calls": [
         {
-          "id": "call_3",
+          "id": "call_3___",
           "type": "function",
           "function": {
             "arguments": "{\"query\": \"what is truth anyway am I right?\"}",
@@ -74,6 +76,7 @@
     },
     {
       "role": "tool",
+      "tool_call_id": "call_3___",
       "name": "brave_search",
       "content": "{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}"
     },
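The fixture IDs are renamed from `call_1` to `call_1___` (and so on) purely to reach nine characters: as the grammar comment earlier notes, Nemo's template expects 9-character tool call IDs, and the upstream Jinja template appears to validate the length when re-rendering history (the padded `call_1___` is exactly nine characters). The matching `tool_call_id` on each `tool` message is what lets the template pair results with their originating calls. Clients that need to synthesize IDs could use something like this (hypothetical helper, not part of the patch):

```cpp
#include <random>
#include <string>

// Generates a 9-character alphanumeric id matching ^[a-zA-Z0-9]{9}$.
static std::string random_nemo_id() {
    static const char alphabet[] =
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    std::mt19937 gen{std::random_device{}()};
    std::uniform_int_distribution<size_t> dist(0, sizeof(alphabet) - 2); // skip trailing '\0'
    std::string id(9, '0');
    for (auto & c : id) {
        c = alphabet[dist(gen)];
    }
    return id;
}
```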
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index 5e47464ce..cee5989d3 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -79,16 +79,21 @@ static void test_parse_tool_call(llama_tool_call_style style, const json & tools
     assert_equals(expected_content, result.content);
     auto tool_calls = json::array();
     for (const auto & tc : result.tool_calls) {
-        tool_calls.push_back({
-            {"type", "function"},
-            {"function", {
-                {"name", tc.name},
-                {"arguments", dump(json::parse(tc.arguments))},
-            }}
-        });
+        auto tool_call = json {
+            {"type", "function"},
+            {"function", {
+                {"arguments", dump(json::parse(tc.arguments))},
+                {"name", tc.name},
+            }},
+        };
+        if (!tc.id.empty()) {
+            tool_call["id"] = tc.id;
+        }
+        tool_calls.push_back(tool_call);
     }
-    auto expected = expected_tool_calls.dump();
-    auto actual = tool_calls.dump();
+    // Reparse / dump w/ non-ordered JSON variant.
+    auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
+    auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
     assert_equals(expected, actual);
 }
@@ -140,7 +145,7 @@ static void test_parsing() {
             {"name", "foo"},
             {"arguments", dump({
                 {"bar", 1}
-            })}
+            })},
         }}
     };

@@ -239,35 +244,38 @@ static void test_parsing() {
             {"arguments", dump({{"code", ""}})}
         }}
     }});
-    auto just_special_function_call = json {{
+    auto special_function_call = json {
         {"type", "function"},
         {"function", {
+            {"arguments", dump({{"arg1", 1}})},
             {"name", "special_function"},
-            {"arguments", dump({{"arg1", 1}})}
-        }}
-    }};
+        }},
+    };
+    auto special_function_call_with_id = json::parse(special_function_call.dump());
+    special_function_call_with_id["id"] = "123456789";
+
     auto no_function_call = json::array();

     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
-        just_special_function_call);
+        json::array({special_function_call}));
     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
-        just_special_function_call);
+        json::array({special_function_call}));
     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
-        just_special_function_call);
+        json::array({special_function_call}));
     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\n    \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
-        just_special_function_call);
+        json::array({special_function_call}));
     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
-        just_special_function_call);
+        json::array({special_function_call}));

     // No match: function unknown
     test_parse_tool_call(llama_tool_call_style::Llama31, tools,
@@ -283,6 +291,15 @@ static void test_parsing() {
         "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
         "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
         no_function_call);
+
+    test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
+        "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
+        "Bleh",
+        json::array({special_function_call_with_id}));
+    test_parse_tool_call(llama_tool_call_style::MistralNemo, tools,
+        "[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
+        "",
+        json::array({special_function_call_with_id}));
 }

 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
@@ -298,6 +315,8 @@ static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
     test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
     test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
+    test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", MistralNemo);
+    test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
 }

 static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
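One detail worth calling out in the comparison helper above: expected and actual are now round-tripped through the non-ordered `nlohmann::json` before dumping. `ordered_json` preserves insertion order, so the same logical object can serialize two different ways; plain `json` keeps object keys in a sorted `std::map`, which normalizes the dump and makes the string comparison order-insensitive. A self-contained illustration:

```cpp
#include <cassert>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::ordered_json a {{"name", "f"}, {"arguments", "{}"}};
    nlohmann::ordered_json b {{"arguments", "{}"}, {"name", "f"}};
    assert(a.dump() != b.dump()); // insertion order preserved: strings differ
    assert(nlohmann::json::parse(a.dump()).dump()
        == nlohmann::json::parse(b.dump()).dump()); // sorted keys: equal
}
```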
@@ -323,7 +342,7 @@ static std::string get_message_prompt_delta(const minja::chat_template & tmpl, c
     return delta;
 }

-static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
+static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
     std::cout << "# Testing template: " << template_file << std::endl << std::flush;
     const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
@@ -342,17 +361,19 @@ static void test_template(const std::string & template_file, const char * bos_to
         throw std::runtime_error("Failed to build grammar");
     }

-    auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
-    std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
-    test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
+    if (!skip_grammar_test) {
+        auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
+        std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
+        test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);

-    auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
-        {"role", "assistant"},
-        {"content", ""},
-        {"tool_calls", tool_calls}
-    }, tools);
-    if (!match_string(content_less_delta, grammar.get())) {
-        throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
+        auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
+            {"role", "assistant"},
+            {"content", ""},
+            {"tool_calls", tool_calls}
+        }, tools);
+        if (!match_string(content_less_delta, grammar.get())) {
+            throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
+        }
     }
 }

@@ -365,9 +386,14 @@ static void test_grammars() {
         {"function", {
             {"name", "special_function"},
             {"arguments", "{\"arg1\": 1}"}
-        }}
+        }},
     }}}
 };
+    auto tool_call_message_with_id = json::parse(tool_call_message.dump());
+    tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
+
+    test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
+        /* skip_grammar_test= */ true);
     test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
     test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
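Why the Nemo template is the one test that sets `skip_grammar_test`: presumably the mismatch already flagged in the grammar comment earlier. The grammar constrains `arguments` to a plain JSON object, while the template (and likely the model) renders arguments as a stringified value, so a template-rendered delta would not match the grammar:

```cpp
// What the grammar accepts (arguments as an object):
//   [{"arguments": {"arg1": 1}, "name": "special_function", "id": "123456789"}]
// What a template that stringifies arguments would render:
//   [{"name": "special_function", "arguments": "{\"arg1\": 1}", "id": "123456789"}]
```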