beef up test-chat-handler w/ delta expectations

This commit is contained in:
ochafik 2025-01-28 14:40:23 +00:00
parent ba10b47ae5
commit cd63ba435e
3 changed files with 210 additions and 285 deletions

View file

@ -541,31 +541,35 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
});
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (params.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else {
builder.add_rule("root", first_rule);
}
if (!params.tools.is_null() && !params.tools.empty()) {
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
});
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (params.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else {
builder.add_rule("root", first_rule);
}
}, grammar_options);
}, grammar_options);
data.format = "functionary v3.2 tool calls";
} else {
data.format = "functionary v3.2 content-only";
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "functionary v3.2 tool calls";
data.parser = [params](const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
@ -763,21 +767,24 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
}
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
if (params.tools.is_null() || params.tool_choice == "none") {
return common_chat_init_without_tools(tmpl, params);
}
if (!params.grammar.empty()) {
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
if (has_tools && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
const auto & src = tmpl.source();
if (src.find(">>>all") != std::string::npos) {
// Functionary prepends "all\n" to plain content outputs, so we always use its parser, with or without tools
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
}
if (has_tools) {
return common_chat_init_without_tools(tmpl, params);
}
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
}
if (src.find(">>>all") != std::string::npos) {
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
}
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);

View file

@ -65,7 +65,13 @@ class chat_template {
try_raw_render({
{{"role", "user"}, {"content", "Hey"}},
}, {
{{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
{
{"type", "function"},
{"function", {
{"name", "some_tool"},
{"parameters", {{"type", "string"}}},
}},
},
}, false).find("some_tool") != std::string::npos;
requires_object_arguments_ =

View file

@ -10,6 +10,29 @@
using json = nlohmann::ordered_json;
// Converts a JSON chat message into a common_chat_msg.
// The role is always set to "assistant". "content" is copied when it is present
// and non-null. Every entry of "tool_calls" (if any) is appended as a tool call
// built from function.name, function.arguments and the optional "id":
// string arguments are taken verbatim, anything else is serialized with dump(),
// and a missing id becomes the empty string.
static common_chat_msg msg_from_json(const json & message) {
    common_chat_msg result {
        "assistant",
        "",
        {},
    };

    if (message.contains("content")) {
        const auto & content = message.at("content");
        if (!content.is_null()) {
            result.content = content.get<std::string>();
        }
    }

    // Nothing more to do for content-only messages.
    if (!message.contains("tool_calls")) {
        return result;
    }

    for (const auto & call : message.at("tool_calls")) {
        const auto & fn   = call.at("function");
        const auto & args = fn.at("arguments");

        std::string name            = fn.at("name").get<std::string>();
        std::string serialized_args = args.is_string() ? args.get<std::string>() : args.dump();
        std::string id;
        if (call.contains("id")) {
            id = call.at("id").get<std::string>();
        }

        result.tool_calls.push_back({ name, serialized_args, id });
    }
    return result;
}
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
@ -139,184 +162,13 @@ const auto code_interpreter_tool = json::parse(R"({
// Default tool set handed to the template tests: the special function plus the python tool.
const json tools = {special_function_tool, python_tool};
// Variant tool set that pairs the special function with code_interpreter_tool instead of python_tool.
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
// static void test_parsing() {
// json request = {
// {"tools", tools}
// };
// Bundle returned by init_delta: the rendered delta text together with the
// grammar and parser taken from the common_chat_init result it was built from.
// NOTE(review): init_delta brace-initializes this in member order
// (delta, grammar, parser); keep the order in sync if fields are added.
struct delta_data {
// Text extracted by init_delta for the delta message.
std::string delta;
// Grammar string from the common_chat_init result; may be empty (test_template checks for that).
std::string grammar;
// Parser from the same result; test_template calls it on `delta`.
common_chat_parser parser;
};
// const auto fooBarCall = json {
// {"type", "function"},
// {"function", {
// {"name", "foo"},
// {"arguments", dump({
// {"bar", 1}
// })},
// }}
// };
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
// "",
// json::array({fooBarCall}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
// "",
// json::array({fooBarCall}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
// "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
// "",
// json::array({fooBarCall}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
// ">>>python\n{\"code\": \"print('Hello, world!')\"}",
// "",
// json {{
// {"type", "function"},
// {"function", {
// {"name", "python"},
// {"arguments", dump({
// {"code", "print('Hello, world!')"}
// })}
// }}
// }});
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
// ">>>special_function\n{\"arg1\": 1}\n ",
// "",
// json {{
// {"type", "function"},
// {"function", {
// {"name", "special_function"},
// {"arguments", dump({
// {"arg1", 1}
// })}
// }}
// }});
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
// "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
// "Hello, world!",
// json {
// {
// {"type", "function"},
// {"function", {
// {"name", "foo"},
// {"arguments", dump({
// {"arg1", 1}
// })}
// }}
// },
// {
// {"type", "function"},
// {"function", {
// {"name", "bar"},
// {"arguments", dump({
// {"arg2", 2}
// })}
// }}
// },
// });
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
// "<function=test>{ } </function> ",
// " ",
// json {{
// {"type", "function"},
// {"function", {
// {"name", "test"},
// {"arguments", "{}"}
// }}
// }});
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "<|python_tag|>this could be anything",
// "",
// json {{
// {"type", "function"},
// {"function", {
// {"name", "python"},
// {"arguments", "this could be anything"},
// }}
// }});
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "I'm thinking<|python_tag|>",
// "I'm thinking",
// json {{
// {"type", "function"},
// {"function", {
// {"name", "python"},
// {"arguments", ""},
// }}
// }});
// auto special_function_call = json {
// {"type", "function"},
// {"function", {
// {"arguments", dump({{"arg1", 1}})},
// {"name", "special_function"},
// }},
// };
// auto special_function_call_with_id = json::parse(special_function_call.dump());
// special_function_call_with_id["id"] = "123456789";
// auto no_function_call = json::array();
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
// "",
// json::array({{
// {"type", "function"},
// {"function", {
// {"arguments", dump({{"code", "print('Hey')"}})},
// {"name", "python"},
// }}
// }}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
// "",
// json::array({special_function_call}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
// "",
// json::array({special_function_call}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
// "",
// json::array({special_function_call}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
// "",
// json::array({special_function_call}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
// "",
// json::array({special_function_call}));
// // No match: function unknown
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// no_function_call);
// // No match: bad indentation
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// no_function_call);
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
// no_function_call);
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
// "Bleh",
// json::array({special_function_call_with_id}));
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
// "Bleh",
// json::array({special_function_call}));
// }
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
@ -325,13 +177,17 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
params.messages = json::array();
params.messages.push_back(user_message);
params.tools = tools;
std::string prefix = common_chat_init(tmpl, params).prompt;
auto prefix_data = common_chat_init(tmpl, params);
params.messages.push_back(delta_message);
params.add_generation_prompt = false;
std::string full = common_chat_init(tmpl, params).prompt;
auto full_data = common_chat_init(tmpl, params);
std::string prefix = prefix_data.prompt;
std::string full = full_data.prompt;
// Check full starts with prefix
if (full.find(prefix) != 0) {
fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
throw std::runtime_error("Full message does not start with prefix");
}
@ -350,27 +206,12 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
break;
}
}
return delta;
return {delta, full_data.grammar, full_data.parser};
}
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) {
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
// auto tool_call_style = common_tool_call_style_detect(tmpl);
common_chat_msg expected_msg {
"assistant",
"",
{},
};
auto has_tool_calls = test_message.contains("tool_calls");
if (has_tool_calls) {
for (const auto & tc : test_message.at("tool_calls")) {
const auto & arguments = tc.at("function").at("arguments");
expected_msg.tool_calls.push_back({
tc.at("function").at("name").get<std::string>(),
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tc.contains("id") ? tc.at("id").get<std::string>() : "",
});
}
}
common_chat_msg expected_msg = msg_from_json(test_message);
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
// get the diff and try and parse it w/ the grammar.
@ -385,35 +226,37 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
params.parallel_tool_calls = true;
params.messages = json {user_message, test_message};
params.tools = tools;
auto chat_data = common_chat_init(tmpl, params);
// fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
if (has_tool_calls) {
auto grammar = build_grammar(chat_data.grammar);
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl;
if (!expected_delta.empty()) {
assert_equals(expected_delta, data.delta);
}
if (!skip_parser_test) {
const auto msg = data.parser(data.delta);
assert_msg_equals(expected_msg, msg);
}
if (!expected_msg.tool_calls.empty()) {
GGML_ASSERT(!data.grammar.empty());
}
if (!data.grammar.empty()) {
auto grammar = build_grammar(data.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
}
// TODO: exercise lazy grammars + triggers here, instead of skipping the test
if (!skip_grammar_test) {
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
const auto msg = chat_data.parser(full_delta);
assert_msg_equals(expected_msg, msg);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"},
{"content", {}},
{"tool_calls", test_message.at("tool_calls")}
}, tools);
if (!match_string(content_less_delta, grammar.get())) {
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
if (!match_string(data.delta, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + "\n\nGrammar: " + data.grammar);
}
}
}
}
}
static void test_grammars() {
static void test_template_output_parsers() {
auto text_message = json {
{"role", "assistant"},
{"content", "Hello, world!"},
@ -465,6 +308,7 @@ static void test_grammars() {
common_chat_params tools_params = no_tools_params;
tools_params.tools = json::array();
tools_params.tools.push_back(special_function_tool);
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
auto data = common_chat_init(tmpl, params);
@ -477,114 +321,182 @@ static void test_grammars() {
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
"{\n"
" \"response\": \"Hello, world!\"\n"
"}"));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|end|>" };
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
"{\n"
" \"response\": \"Hello, world!\"\n"
"}"));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "</s>" };
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, python_tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
// test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
test_template(tmpl, end_tokens, python_tool_call_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
"<|python_tag|>python.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"all\n"
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"special_function\n"
"{\"arg1\": 1}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eot_id|>" };
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end▁of▁sentence>" };
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n"
"{\"arg1\": 1}\n"
"```<tool▁call▁end>");
}
}
int main() {
// test_parsing();
test_grammars();
test_template_output_parsers();
std::cout << "\n[tool-call] All tests passed!" << std::endl;
return 0;