From dbda025f87234149b9bf34fb917875cfc81f2c34 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 28 Sep 2024 22:32:47 +0100 Subject: [PATCH] `tool-call`: test messages -> template -> grammar -> tool call parser --- common/chat-template.cpp | 4 +- common/tool-call.cpp | 4 +- examples/agent/README.md | 8 +- tests/test-tool-call.cpp | 238 ++++++++++++++++++++++++++++++--------- 4 files changed, 191 insertions(+), 63 deletions(-) diff --git a/common/chat-template.cpp b/common/chat-template.cpp index ed2340f45..7234e524c 100644 --- a/common/chat-template.cpp +++ b/common/chat-template.cpp @@ -34,7 +34,9 @@ llama_chat_template::llama_chat_template(const std::string & chat_template, cons : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) { _supports_tools = chat_template.find("tools") != std::string::npos; - _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos; + _requires_object_arguments = + chat_template.find("tool_call.arguments | items") != std::string::npos + || chat_template.find("{{- tool_call.arguments | tojson }}") != std::string::npos; _supports_system_role = chat_template.find("System role not supported") == std::string::npos; if (chat_template.find("") != std::string::npos) { diff --git a/common/tool-call.cpp b/common/tool-call.cpp index b0f4698e7..55d5cae59 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init( handler.grammar_trigger_words.push_back("<|python_tag|>"); } } else { - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\"")); + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); } } auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; @@ -349,7 +349,7 @@ llama_tool_call_handler llama_tool_call_handler_init( })); } - auto tool_call = "\"\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; + auto tool_call = "\"\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back(""); diff --git a/examples/agent/README.md b/examples/agent/README.md index 45b159815..8845819f0 100644 --- a/examples/agent/README.md +++ b/examples/agent/README.md @@ -16,6 +16,10 @@ ./llama-server --jinja -fa --verbose \ -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf + # Llama 3.1 70B + ./llama-server --jinja -fa --verbose \ + -hfr lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF -hff Meta-Llama-3.1-70B-Instruct-Q4_K_M.gguf + # functionary-small-v3 ./llama-server --jinja -fa --verbose \ -hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q4_0.gguf \ @@ -38,10 +42,6 @@ ./llama-server --jinja -fa --verbose \ -hfr lmstudio-community/Llama-3.2-1B-Instruct-GGUF -hff Llama-3.2-1B-Instruct-Q4_K_M.gguf \ --chat-template-file tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja - - # Llama 3.1 70B (untested) - ./llama-server --jinja -fa --verbose \ - -hfr lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF -hff Meta-Llama-3.1-70B-Instruct-Q4_K_M.gguf ``` - Run some tools inside a docker container (check http://localhost:8088/docs once running): diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index 717758432..b3a824db7 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -1,4 +1,6 @@ #include "tool-call.h" +#include "llama-grammar.h" +#include "unicode.h" #include #include @@ -30,9 +32,42 @@ static std::string read_file(const std::string &path) { return out; } -/* - cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call -*/ +static llama_grammar * build_grammar(const std::string & grammar_str) { + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); +} + +// TODO: extract to common helper (copied from test-grammar-integration.cpp) +static bool match_string(const std::string & input, llama_grammar * grammar) { + const auto cpts = unicode_cpts_from_utf8(input); + + const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + + for (const auto & cpt : cpts) { + const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy + + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + + if (stacks_cur.empty()) { + // no stacks means that the grammar failed to match at this point + return false; + } + } + + for (const auto & stack : stacks_cur) { + if (stack.empty()) { + // An empty stack means that the grammar has been completed + return true; + } + } + + return false; +} + +// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`. +static std::string dump(const json & j) { + return minja::Value(j).dump(-1, /* to_json= */ true); +} static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { std::cout << "# Testing: " << input << std::endl << std::flush; @@ -41,51 +76,56 @@ static void test_parse_tool_call(llama_tool_call_style style, const json & tools auto tool_calls = json::array(); for (const auto & tc : result.tool_calls) { tool_calls.push_back({ + {"type", "function"}, {"function", { {"name", tc.name}, - {"arguments", tc.arguments}, + {"arguments", dump(json::parse(tc.arguments))}, }} }); } - assert_equals(expected_tool_calls.dump(), tool_calls.dump()); + auto expected = expected_tool_calls.dump(); + auto actual = tool_calls.dump(); + assert_equals(expected, actual); } -int main() { - json tools = json::parse(R"([ - { - "type": "function", - "function": { - "name": "special_function", - "description": "I'm special", - "parameters": { - "type": "object", - "properties": { - "arg1": { - "type": "string", - "description": "The arg." - } - }, - "required": ["arg1"] + +const json tools = json::parse(R"([ + { + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." } - } - }, - { - "type": "function", - "function": { - "name": "ipython", - "description": "a python interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The code." - } - }, - "required": ["code"] - } - } + }, + "required": ["arg1"] } - ])"); + } + }, + { + "type": "function", + "function": { + "name": "ipython", + "description": "a python interpreter", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code." + } + }, + "required": ["code"] + } + } + } +])"); + +static void test_parsing() { json request = { {"tools", tools} }; @@ -94,11 +134,12 @@ int main() { "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", "", json {{ + {"type", "function"}, {"function", { {"name", "foo"}, - {"arguments", (json { + {"arguments", dump({ {"bar", 1} - }).dump()} + })} }} }}); @@ -106,22 +147,24 @@ int main() { ">>>ipython\n{\"code\": \"print('Hello, world!')\"}", "", json {{ + {"type", "function"}, {"function", { {"name", "ipython"}, - {"arguments", (json { + {"arguments", dump({ {"code", "print('Hello, world!')"} - }).dump()} + })} }} }}); test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, ">>>special_function\n{\"arg1\": 1}\n ", "", json {{ + {"type", "function"}, {"function", { {"name", "special_function"}, - {"arguments", (json { + {"arguments", dump({ {"arg1", 1} - }).dump()} + })} }} }}); @@ -130,19 +173,21 @@ int main() { "Hello, world!", json { { + {"type", "function"}, {"function", { {"name", "foo"}, - {"arguments", (json { + {"arguments", dump({ {"arg1", 1} - }).dump()} + })} }} }, { + {"type", "function"}, {"function", { {"name", "bar"}, - {"arguments", (json { + {"arguments", dump({ {"arg2", 2} - }).dump()} + })} }} }, }); @@ -150,6 +195,7 @@ int main() { "{ } ", " ", json {{ + {"type", "function"}, {"function", { {"name", "test"}, {"arguments", "{}"} @@ -160,36 +206,116 @@ int main() { "<|python_tag|>this could be anything", "", json {{ + {"type", "function"}, {"function", { {"name", "ipython"}, - {"arguments", (json { + {"arguments", dump({ {"code", "this could be anything"} - }).dump()} + })} }} }}); test_parse_tool_call(llama_tool_call_style::Llama31, tools, "I'm thinking<|python_tag|>", "I'm thinking", json {{ + {"type", "function"}, {"function", { {"name", "ipython"}, - {"arguments", (json {{"code", ""}}).dump()} + {"arguments", dump({{"code", ""}})} }} }}); test_parse_tool_call(llama_tool_call_style::Llama31, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "", json {{ + {"type", "function"}, {"function", { {"name", "special_function"}, - {"arguments", (json { - {"arg1", 1} - }).dump()} + {"arguments", dump({{"arg1", 1}})} }} }}); test_parse_tool_call(llama_tool_call_style::Llama31, tools, "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); +} + +static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { + auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); + auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); + + // Check full starts with prefix + if (full.find(prefix) != 0) { + throw std::runtime_error("Full message does not start with prefix"); + } + + auto delta = full.substr(prefix.size()); + + // Strip end tokens + for (const auto & end_token : end_tokens) { + // rfind to find the last occurrence + auto pos = delta.rfind(end_token); + if (pos != std::string::npos) { + delta = delta.substr(0, pos); + break; + } + } + return delta; +} + +static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector & end_tokens, const json & tool_calling_message, const json & tools) { + std::cout << "# Testing template: " << template_file << std::endl << std::flush; + const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token); + auto & tool_calls = tool_calling_message.at("tool_calls"); + + // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, + // get the diff and try and parse it w/ the grammar. + auto user_message = json { + {"role", "user"}, + {"content", "Hello, world!"} + }; + + auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); + auto grammar = build_grammar(handler.grammar); + if (!grammar) { + throw std::runtime_error("Failed to build grammar"); + } + + auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools); + std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; + test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", tool_calls); + + auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { + {"role", "assistant"}, + {"content", ""}, + {"tool_calls", tool_calls} + }, tools); + if (!match_string(content_less_delta, grammar)) { + throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar); + } +} + +static void test_grammars() { + auto tool_call_message = json { + {"role", "assistant"}, + {"content", ""}, + {"tool_calls", json {{ + {"type", "function"}, + {"function", { + {"name", "special_function"}, + {"arguments", "{\"arg1\": 1}"} + }} + }}} + }; + test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "", "", { "<|im_end|>" }, tool_call_message, tools); + test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); +} + +int main() { + test_grammars(); + test_parsing(); std::cout << "[tool-call] All tests passed!" << std::endl; return 0;