tool-call: test messages -> template -> grammar -> tool call parser
This commit is contained in:
parent
0ae1112faa
commit
dbda025f87
4 changed files with 191 additions and 63 deletions
|
@ -34,7 +34,9 @@ llama_chat_template::llama_chat_template(const std::string & chat_template, cons
|
||||||
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
|
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
|
||||||
|
|
||||||
_supports_tools = chat_template.find("tools") != std::string::npos;
|
_supports_tools = chat_template.find("tools") != std::string::npos;
|
||||||
_requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
|
_requires_object_arguments =
|
||||||
|
chat_template.find("tool_call.arguments | items") != std::string::npos
|
||||||
|
|| chat_template.find("{{- tool_call.arguments | tojson }}") != std::string::npos;
|
||||||
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
|
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
|
||||||
|
|
||||||
if (chat_template.find("<tool_call>") != std::string::npos) {
|
if (chat_template.find("<tool_call>") != std::string::npos) {
|
||||||
|
|
|
@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
|
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||||
|
@ -349,7 +349,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
||||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
if (allow_content) {
|
if (allow_content) {
|
||||||
handler.grammar_trigger_words.push_back("<tool_call>");
|
handler.grammar_trigger_words.push_back("<tool_call>");
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
./llama-server --jinja -fa --verbose \
|
./llama-server --jinja -fa --verbose \
|
||||||
-hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf
|
-hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf
|
||||||
|
|
||||||
|
# Llama 3.1 70B
|
||||||
|
./llama-server --jinja -fa --verbose \
|
||||||
|
-hfr lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF -hff Meta-Llama-3.1-70B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
# functionary-small-v3
|
# functionary-small-v3
|
||||||
./llama-server --jinja -fa --verbose \
|
./llama-server --jinja -fa --verbose \
|
||||||
-hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q4_0.gguf \
|
-hfr meetkai/functionary-small-v3.2-GGUF -hff functionary-small-v3.2.Q4_0.gguf \
|
||||||
|
@ -38,10 +42,6 @@
|
||||||
./llama-server --jinja -fa --verbose \
|
./llama-server --jinja -fa --verbose \
|
||||||
-hfr lmstudio-community/Llama-3.2-1B-Instruct-GGUF -hff Llama-3.2-1B-Instruct-Q4_K_M.gguf \
|
-hfr lmstudio-community/Llama-3.2-1B-Instruct-GGUF -hff Llama-3.2-1B-Instruct-Q4_K_M.gguf \
|
||||||
--chat-template-file tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
|
--chat-template-file tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
|
||||||
|
|
||||||
# Llama 3.1 70B (untested)
|
|
||||||
./llama-server --jinja -fa --verbose \
|
|
||||||
-hfr lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF -hff Meta-Llama-3.1-70B-Instruct-Q4_K_M.gguf
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run some tools inside a docker container (check http://localhost:8088/docs once running):
|
- Run some tools inside a docker container (check http://localhost:8088/docs once running):
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
#include "tool-call.h"
|
#include "tool-call.h"
|
||||||
|
#include "llama-grammar.h"
|
||||||
|
#include "unicode.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -30,9 +32,42 @@ static std::string read_file(const std::string &path) {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
static llama_grammar * build_grammar(const std::string & grammar_str) {
|
||||||
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
|
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
|
||||||
*/
|
}
|
||||||
|
|
||||||
|
// TODO: extract to common helper (copied from test-grammar-integration.cpp)
|
||||||
|
static bool match_string(const std::string & input, llama_grammar * grammar) {
|
||||||
|
const auto cpts = unicode_cpts_from_utf8(input);
|
||||||
|
|
||||||
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
|
for (const auto & cpt : cpts) {
|
||||||
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
|
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||||
|
|
||||||
|
if (stacks_cur.empty()) {
|
||||||
|
// no stacks means that the grammar failed to match at this point
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & stack : stacks_cur) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
// An empty stack means that the grammar has been completed
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
|
||||||
|
static std::string dump(const json & j) {
|
||||||
|
return minja::Value(j).dump(-1, /* to_json= */ true);
|
||||||
|
}
|
||||||
|
|
||||||
static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
||||||
std::cout << "# Testing: " << input << std::endl << std::flush;
|
std::cout << "# Testing: " << input << std::endl << std::flush;
|
||||||
|
@ -41,16 +76,19 @@ static void test_parse_tool_call(llama_tool_call_style style, const json & tools
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
for (const auto & tc : result.tool_calls) {
|
for (const auto & tc : result.tool_calls) {
|
||||||
tool_calls.push_back({
|
tool_calls.push_back({
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", tc.name},
|
{"name", tc.name},
|
||||||
{"arguments", tc.arguments},
|
{"arguments", dump(json::parse(tc.arguments))},
|
||||||
}}
|
}}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
assert_equals(expected_tool_calls.dump(), tool_calls.dump());
|
auto expected = expected_tool_calls.dump();
|
||||||
|
auto actual = tool_calls.dump();
|
||||||
|
assert_equals(expected, actual);
|
||||||
}
|
}
|
||||||
int main() {
|
|
||||||
json tools = json::parse(R"([
|
const json tools = json::parse(R"([
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
|
@ -60,7 +98,7 @@ int main() {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"arg1": {
|
"arg1": {
|
||||||
"type": "string",
|
"type": "integer",
|
||||||
"description": "The arg."
|
"description": "The arg."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -86,6 +124,8 @@ int main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
])");
|
])");
|
||||||
|
|
||||||
|
static void test_parsing() {
|
||||||
json request = {
|
json request = {
|
||||||
{"tools", tools}
|
{"tools", tools}
|
||||||
};
|
};
|
||||||
|
@ -94,11 +134,12 @@ int main() {
|
||||||
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "foo"},
|
{"name", "foo"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"bar", 1}
|
{"bar", 1}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
|
|
||||||
|
@ -106,22 +147,24 @@ int main() {
|
||||||
">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
|
">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "ipython"},
|
{"name", "ipython"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"code", "print('Hello, world!')"}
|
{"code", "print('Hello, world!')"}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
||||||
">>>special_function\n{\"arg1\": 1}\n ",
|
">>>special_function\n{\"arg1\": 1}\n ",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "special_function"},
|
{"name", "special_function"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"arg1", 1}
|
{"arg1", 1}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
|
|
||||||
|
@ -130,19 +173,21 @@ int main() {
|
||||||
"Hello, world!",
|
"Hello, world!",
|
||||||
json {
|
json {
|
||||||
{
|
{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "foo"},
|
{"name", "foo"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"arg1", 1}
|
{"arg1", 1}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "bar"},
|
{"name", "bar"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"arg2", 2}
|
{"arg2", 2}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -150,6 +195,7 @@ int main() {
|
||||||
"<function=test>{ } </function> ",
|
"<function=test>{ } </function> ",
|
||||||
" ",
|
" ",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "test"},
|
{"name", "test"},
|
||||||
{"arguments", "{}"}
|
{"arguments", "{}"}
|
||||||
|
@ -160,36 +206,116 @@ int main() {
|
||||||
"<|python_tag|>this could be anything",
|
"<|python_tag|>this could be anything",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "ipython"},
|
{"name", "ipython"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({
|
||||||
{"code", "this could be anything"}
|
{"code", "this could be anything"}
|
||||||
}).dump()}
|
})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"I'm thinking<|python_tag|>",
|
"I'm thinking<|python_tag|>",
|
||||||
"I'm thinking",
|
"I'm thinking",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "ipython"},
|
{"name", "ipython"},
|
||||||
{"arguments", (json {{"code", ""}}).dump()}
|
{"arguments", dump({{"code", ""}})}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "special_function"},
|
{"name", "special_function"},
|
||||||
{"arguments", (json {
|
{"arguments", dump({{"arg1", 1}})}
|
||||||
{"arg1", 1}
|
|
||||||
}).dump()}
|
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
|
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
|
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
|
||||||
|
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
|
||||||
|
|
||||||
|
// Check full starts with prefix
|
||||||
|
if (full.find(prefix) != 0) {
|
||||||
|
throw std::runtime_error("Full message does not start with prefix");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto delta = full.substr(prefix.size());
|
||||||
|
|
||||||
|
// Strip end tokens
|
||||||
|
for (const auto & end_token : end_tokens) {
|
||||||
|
// rfind to find the last occurrence
|
||||||
|
auto pos = delta.rfind(end_token);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
delta = delta.substr(0, pos);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
|
||||||
|
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
|
||||||
|
const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token);
|
||||||
|
auto & tool_calls = tool_calling_message.at("tool_calls");
|
||||||
|
|
||||||
|
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||||
|
// get the diff and try and parse it w/ the grammar.
|
||||||
|
auto user_message = json {
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", "Hello, world!"}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
|
||||||
|
auto grammar = build_grammar(handler.grammar);
|
||||||
|
if (!grammar) {
|
||||||
|
throw std::runtime_error("Failed to build grammar");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
||||||
|
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||||
|
test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", tool_calls);
|
||||||
|
|
||||||
|
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", ""},
|
||||||
|
{"tool_calls", tool_calls}
|
||||||
|
}, tools);
|
||||||
|
if (!match_string(content_less_delta, grammar)) {
|
||||||
|
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_grammars() {
|
||||||
|
auto tool_call_message = json {
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", ""},
|
||||||
|
{"tool_calls", json {{
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"name", "special_function"},
|
||||||
|
{"arguments", "{\"arg1\": 1}"}
|
||||||
|
}}
|
||||||
|
}}}
|
||||||
|
};
|
||||||
|
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
||||||
|
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
test_grammars();
|
||||||
|
test_parsing();
|
||||||
|
|
||||||
std::cout << "[tool-call] All tests passed!" << std::endl;
|
std::cout << "[tool-call] All tests passed!" << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue