beef up test-chat-handler w/ delta expectations
This commit is contained in:
parent
ba10b47ae5
commit
cd63ba435e
3 changed files with 210 additions and 285 deletions
|
@ -541,31 +541,35 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
common_chat_data data;
|
||||
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
});
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (params.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||
} else {
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
if (!params.tools.is_null() && !params.tools.empty()) {
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
});
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (params.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||
} else {
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
|
||||
}, grammar_options);
|
||||
}, grammar_options);
|
||||
data.format = "functionary v3.2 tool calls";
|
||||
} else {
|
||||
data.format = "functionary v3.2 content-only";
|
||||
}
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.format = "functionary v3.2 tool calls";
|
||||
data.parser = [params](const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
@ -763,21 +767,24 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
|
|||
}
|
||||
|
||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
if (params.tools.is_null() || params.tool_choice == "none") {
|
||||
return common_chat_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
|
||||
if (has_tools && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
|
||||
const auto & src = tmpl.source();
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
||||
}
|
||||
|
||||
if (has_tools) {
|
||||
return common_chat_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
||||
}
|
||||
if (src.find(">>>all") != std::string::npos) {
|
||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
||||
}
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
|
||||
|
|
|
@ -65,7 +65,13 @@ class chat_template {
|
|||
try_raw_render({
|
||||
{{"role", "user"}, {"content", "Hey"}},
|
||||
}, {
|
||||
{{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
|
||||
{
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "some_tool"},
|
||||
{"parameters", {{"type", "string"}}},
|
||||
}},
|
||||
},
|
||||
}, false).find("some_tool") != std::string::npos;
|
||||
|
||||
requires_object_arguments_ =
|
||||
|
|
|
@ -10,6 +10,29 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static common_chat_msg msg_from_json(const json & message) {
|
||||
common_chat_msg ret {
|
||||
"assistant",
|
||||
"",
|
||||
{},
|
||||
};
|
||||
if (message.contains("content") && !message.at("content").is_null()) {
|
||||
ret.content = message.at("content").get<std::string>();
|
||||
}
|
||||
auto has_tool_calls = message.contains("tool_calls");
|
||||
if (has_tool_calls) {
|
||||
for (const auto & tc : message.at("tool_calls")) {
|
||||
const auto & arguments = tc.at("function").at("arguments");
|
||||
ret.tool_calls.push_back({
|
||||
tc.at("function").at("name").get<std::string>(),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
||||
});
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
|
@ -139,184 +162,13 @@ const auto code_interpreter_tool = json::parse(R"({
|
|||
const json tools = {special_function_tool, python_tool};
|
||||
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
||||
|
||||
// static void test_parsing() {
|
||||
// json request = {
|
||||
// {"tools", tools}
|
||||
// };
|
||||
struct delta_data {
|
||||
std::string delta;
|
||||
std::string grammar;
|
||||
common_chat_parser parser;
|
||||
};
|
||||
|
||||
// const auto fooBarCall = json {
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "foo"},
|
||||
// {"arguments", dump({
|
||||
// {"bar", 1}
|
||||
// })},
|
||||
// }}
|
||||
// };
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
||||
// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
|
||||
// "",
|
||||
// json::array({fooBarCall}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
||||
// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
|
||||
// "",
|
||||
// json::array({fooBarCall}));
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
|
||||
// "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
||||
// "",
|
||||
// json::array({fooBarCall}));
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
||||
// ">>>python\n{\"code\": \"print('Hello, world!')\"}",
|
||||
// "",
|
||||
// json {{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "python"},
|
||||
// {"arguments", dump({
|
||||
// {"code", "print('Hello, world!')"}
|
||||
// })}
|
||||
// }}
|
||||
// }});
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
||||
// ">>>special_function\n{\"arg1\": 1}\n ",
|
||||
// "",
|
||||
// json {{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "special_function"},
|
||||
// {"arguments", dump({
|
||||
// {"arg1", 1}
|
||||
// })}
|
||||
// }}
|
||||
// }});
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
||||
// "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
||||
// "Hello, world!",
|
||||
// json {
|
||||
// {
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "foo"},
|
||||
// {"arguments", dump({
|
||||
// {"arg1", 1}
|
||||
// })}
|
||||
// }}
|
||||
// },
|
||||
// {
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "bar"},
|
||||
// {"arguments", dump({
|
||||
// {"arg2", 2}
|
||||
// })}
|
||||
// }}
|
||||
// },
|
||||
// });
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
||||
// "<function=test>{ } </function> ",
|
||||
// " ",
|
||||
// json {{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "test"},
|
||||
// {"arguments", "{}"}
|
||||
// }}
|
||||
// }});
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "<|python_tag|>this could be anything",
|
||||
// "",
|
||||
// json {{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "python"},
|
||||
// {"arguments", "this could be anything"},
|
||||
// }}
|
||||
// }});
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "I'm thinking<|python_tag|>",
|
||||
// "I'm thinking",
|
||||
// json {{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"name", "python"},
|
||||
// {"arguments", ""},
|
||||
// }}
|
||||
// }});
|
||||
// auto special_function_call = json {
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"arguments", dump({{"arg1", 1}})},
|
||||
// {"name", "special_function"},
|
||||
// }},
|
||||
// };
|
||||
// auto special_function_call_with_id = json::parse(special_function_call.dump());
|
||||
// special_function_call_with_id["id"] = "123456789";
|
||||
|
||||
// auto no_function_call = json::array();
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
|
||||
// "",
|
||||
// json::array({{
|
||||
// {"type", "function"},
|
||||
// {"function", {
|
||||
// {"arguments", dump({{"code", "print('Hey')"}})},
|
||||
// {"name", "python"},
|
||||
// }}
|
||||
// }}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
// "",
|
||||
// json::array({special_function_call}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
// "",
|
||||
// json::array({special_function_call}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
// "",
|
||||
// json::array({special_function_call}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
// "",
|
||||
// json::array({special_function_call}));
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
// "",
|
||||
// json::array({special_function_call}));
|
||||
|
||||
// // No match: function unknown
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// no_function_call);
|
||||
// // No match: bad indentation
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// no_function_call);
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
// no_function_call);
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
|
||||
// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
||||
// "Bleh",
|
||||
// json::array({special_function_call_with_id}));
|
||||
|
||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
|
||||
// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
|
||||
// "Bleh",
|
||||
// json::array({special_function_call}));
|
||||
// }
|
||||
|
||||
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
||||
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
||||
|
||||
|
@ -325,13 +177,17 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
|||
params.messages = json::array();
|
||||
params.messages.push_back(user_message);
|
||||
params.tools = tools;
|
||||
std::string prefix = common_chat_init(tmpl, params).prompt;
|
||||
auto prefix_data = common_chat_init(tmpl, params);
|
||||
params.messages.push_back(delta_message);
|
||||
params.add_generation_prompt = false;
|
||||
std::string full = common_chat_init(tmpl, params).prompt;
|
||||
auto full_data = common_chat_init(tmpl, params);
|
||||
|
||||
std::string prefix = prefix_data.prompt;
|
||||
std::string full = full_data.prompt;
|
||||
|
||||
// Check full starts with prefix
|
||||
if (full.find(prefix) != 0) {
|
||||
fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
|
||||
throw std::runtime_error("Full message does not start with prefix");
|
||||
}
|
||||
|
||||
|
@ -350,27 +206,12 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
|||
break;
|
||||
}
|
||||
}
|
||||
return delta;
|
||||
return {delta, full_data.grammar, full_data.parser};
|
||||
}
|
||||
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) {
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
|
||||
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||
common_chat_msg expected_msg {
|
||||
"assistant",
|
||||
"",
|
||||
{},
|
||||
};
|
||||
auto has_tool_calls = test_message.contains("tool_calls");
|
||||
if (has_tool_calls) {
|
||||
for (const auto & tc : test_message.at("tool_calls")) {
|
||||
const auto & arguments = tc.at("function").at("arguments");
|
||||
expected_msg.tool_calls.push_back({
|
||||
tc.at("function").at("name").get<std::string>(),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
||||
});
|
||||
}
|
||||
}
|
||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||
|
||||
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||
// get the diff and try and parse it w/ the grammar.
|
||||
|
@ -385,35 +226,37 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
params.parallel_tool_calls = true;
|
||||
params.messages = json {user_message, test_message};
|
||||
params.tools = tools;
|
||||
auto chat_data = common_chat_init(tmpl, params);
|
||||
// fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
|
||||
if (has_tool_calls) {
|
||||
auto grammar = build_grammar(chat_data.grammar);
|
||||
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||
std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl;
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
|
||||
if (!skip_parser_test) {
|
||||
const auto msg = data.parser(data.delta);
|
||||
assert_msg_equals(expected_msg, msg);
|
||||
}
|
||||
|
||||
if (!expected_msg.tool_calls.empty()) {
|
||||
GGML_ASSERT(!data.grammar.empty());
|
||||
}
|
||||
if (!data.grammar.empty()) {
|
||||
auto grammar = build_grammar(data.grammar);
|
||||
if (!grammar) {
|
||||
throw std::runtime_error("Failed to build grammar");
|
||||
}
|
||||
|
||||
// TODO: exercice lazy grammars + triggers here, instead of skipping the test
|
||||
if (!skip_grammar_test) {
|
||||
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||
|
||||
const auto msg = chat_data.parser(full_delta);
|
||||
assert_msg_equals(expected_msg, msg);
|
||||
|
||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||
{"role", "assistant"},
|
||||
{"content", {}},
|
||||
{"tool_calls", test_message.at("tool_calls")}
|
||||
}, tools);
|
||||
if (!match_string(content_less_delta, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
||||
if (!match_string(data.delta, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + "\n\nGrammar: " + data.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_grammars() {
|
||||
static void test_template_output_parsers() {
|
||||
auto text_message = json {
|
||||
{"role", "assistant"},
|
||||
{"content", "Hello, world!"},
|
||||
|
@ -465,6 +308,7 @@ static void test_grammars() {
|
|||
|
||||
common_chat_params tools_params = no_tools_params;
|
||||
tools_params.tools = json::array();
|
||||
tools_params.tools.push_back(special_function_tool);
|
||||
|
||||
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
||||
auto data = common_chat_init(tmpl, params);
|
||||
|
@ -477,114 +321,182 @@ static void test_grammars() {
|
|||
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\",\n"
|
||||
" \"arguments\": {\n"
|
||||
" \"arg1\": 1\n"
|
||||
" },\n"
|
||||
" \"id\": \"123456789\"\n"
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end|>" };
|
||||
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\",\n"
|
||||
" \"arguments\": {\n"
|
||||
" \"arg1\": 1\n"
|
||||
" },\n"
|
||||
" \"id\": \"123456789\"\n"
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||
/* skip_grammar_test= */ true);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
// test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"all\n"
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"special_function\n"
|
||||
"{\"arg1\": 1}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|>");
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
// test_parsing();
|
||||
test_grammars();
|
||||
test_template_output_parsers();
|
||||
|
||||
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
||||
return 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue