beef up test-chat-handler w/ delta expectations
This commit is contained in:
parent
ba10b47ae5
commit
cd63ba435e
3 changed files with 210 additions and 285 deletions
|
@ -541,31 +541,35 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
if (!params.tools.is_null() && !params.tools.empty()) {
|
||||||
std::vector<std::string> first_tool_rules;
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> subsequent_tool_rules;
|
std::vector<std::string> first_tool_rules;
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
std::vector<std::string> subsequent_tool_rules;
|
||||||
const auto & function = tool["function"];
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
std::string name = function["name"];
|
const auto & function = tool["function"];
|
||||||
auto parameters = function["parameters"];
|
std::string name = function["name"];
|
||||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
auto parameters = function["parameters"];
|
||||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||||
});
|
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
});
|
||||||
if (params.parallel_tool_calls) {
|
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
if (params.parallel_tool_calls) {
|
||||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||||
} else {
|
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||||
builder.add_rule("root", first_rule);
|
} else {
|
||||||
}
|
builder.add_rule("root", first_rule);
|
||||||
|
}
|
||||||
|
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
data.format = "functionary v3.2 tool calls";
|
||||||
|
} else {
|
||||||
|
data.format = "functionary v3.2 content-only";
|
||||||
|
}
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||||
data.format = "functionary v3.2 tool calls";
|
|
||||||
data.parser = [params](const std::string & input) {
|
data.parser = [params](const std::string & input) {
|
||||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||||
static std::regex close_regex(R"($|(?=>>>))");
|
static std::regex close_regex(R"($|(?=>>>))");
|
||||||
|
@ -763,21 +767,24 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
|
||||||
}
|
}
|
||||||
|
|
||||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
if (params.tools.is_null() || params.tool_choice == "none") {
|
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
|
||||||
return common_chat_init_without_tools(tmpl, params);
|
if (has_tools && !params.grammar.empty()) {
|
||||||
}
|
|
||||||
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
throw std::runtime_error("Cannot specify grammar with tools");
|
throw std::runtime_error("Cannot specify grammar with tools");
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & src = tmpl.source();
|
const auto & src = tmpl.source();
|
||||||
|
if (src.find(">>>all") != std::string::npos) {
|
||||||
|
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
|
||||||
|
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_tools) {
|
||||||
|
return common_chat_init_without_tools(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
if (src.find("<tool_call>") != std::string::npos) {
|
if (src.find("<tool_call>") != std::string::npos) {
|
||||||
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
if (src.find(">>>all") != std::string::npos) {
|
|
||||||
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
|
|
||||||
}
|
|
||||||
if (src.find("<|start_header_id|>") != std::string::npos
|
if (src.find("<|start_header_id|>") != std::string::npos
|
||||||
&& src.find("<function=") != std::string::npos) {
|
&& src.find("<function=") != std::string::npos) {
|
||||||
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
|
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
|
||||||
|
|
|
@ -65,7 +65,13 @@ class chat_template {
|
||||||
try_raw_render({
|
try_raw_render({
|
||||||
{{"role", "user"}, {"content", "Hey"}},
|
{{"role", "user"}, {"content", "Hey"}},
|
||||||
}, {
|
}, {
|
||||||
{{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
|
{
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"name", "some_tool"},
|
||||||
|
{"parameters", {{"type", "string"}}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
}, false).find("some_tool") != std::string::npos;
|
}, false).find("some_tool") != std::string::npos;
|
||||||
|
|
||||||
requires_object_arguments_ =
|
requires_object_arguments_ =
|
||||||
|
|
|
@ -10,6 +10,29 @@
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
static common_chat_msg msg_from_json(const json & message) {
|
||||||
|
common_chat_msg ret {
|
||||||
|
"assistant",
|
||||||
|
"",
|
||||||
|
{},
|
||||||
|
};
|
||||||
|
if (message.contains("content") && !message.at("content").is_null()) {
|
||||||
|
ret.content = message.at("content").get<std::string>();
|
||||||
|
}
|
||||||
|
auto has_tool_calls = message.contains("tool_calls");
|
||||||
|
if (has_tool_calls) {
|
||||||
|
for (const auto & tc : message.at("tool_calls")) {
|
||||||
|
const auto & arguments = tc.at("function").at("arguments");
|
||||||
|
ret.tool_calls.push_back({
|
||||||
|
tc.at("function").at("name").get<std::string>(),
|
||||||
|
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||||
|
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
static void assert_equals(const T & expected, const T & actual) {
|
static void assert_equals(const T & expected, const T & actual) {
|
||||||
if (expected != actual) {
|
if (expected != actual) {
|
||||||
|
@ -139,184 +162,13 @@ const auto code_interpreter_tool = json::parse(R"({
|
||||||
const json tools = {special_function_tool, python_tool};
|
const json tools = {special_function_tool, python_tool};
|
||||||
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
||||||
|
|
||||||
// static void test_parsing() {
|
struct delta_data {
|
||||||
// json request = {
|
std::string delta;
|
||||||
// {"tools", tools}
|
std::string grammar;
|
||||||
// };
|
common_chat_parser parser;
|
||||||
|
};
|
||||||
|
|
||||||
// const auto fooBarCall = json {
|
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "foo"},
|
|
||||||
// {"arguments", dump({
|
|
||||||
// {"bar", 1}
|
|
||||||
// })},
|
|
||||||
// }}
|
|
||||||
// };
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
|
||||||
// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
|
|
||||||
// "",
|
|
||||||
// json::array({fooBarCall}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
|
||||||
// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
|
|
||||||
// "",
|
|
||||||
// json::array({fooBarCall}));
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
|
|
||||||
// "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
|
||||||
// "",
|
|
||||||
// json::array({fooBarCall}));
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
|
||||||
// ">>>python\n{\"code\": \"print('Hello, world!')\"}",
|
|
||||||
// "",
|
|
||||||
// json {{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "python"},
|
|
||||||
// {"arguments", dump({
|
|
||||||
// {"code", "print('Hello, world!')"}
|
|
||||||
// })}
|
|
||||||
// }}
|
|
||||||
// }});
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
|
||||||
// ">>>special_function\n{\"arg1\": 1}\n ",
|
|
||||||
// "",
|
|
||||||
// json {{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "special_function"},
|
|
||||||
// {"arguments", dump({
|
|
||||||
// {"arg1", 1}
|
|
||||||
// })}
|
|
||||||
// }}
|
|
||||||
// }});
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
|
||||||
// "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
|
||||||
// "Hello, world!",
|
|
||||||
// json {
|
|
||||||
// {
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "foo"},
|
|
||||||
// {"arguments", dump({
|
|
||||||
// {"arg1", 1}
|
|
||||||
// })}
|
|
||||||
// }}
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "bar"},
|
|
||||||
// {"arguments", dump({
|
|
||||||
// {"arg2", 2}
|
|
||||||
// })}
|
|
||||||
// }}
|
|
||||||
// },
|
|
||||||
// });
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
|
||||||
// "<function=test>{ } </function> ",
|
|
||||||
// " ",
|
|
||||||
// json {{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "test"},
|
|
||||||
// {"arguments", "{}"}
|
|
||||||
// }}
|
|
||||||
// }});
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "<|python_tag|>this could be anything",
|
|
||||||
// "",
|
|
||||||
// json {{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "python"},
|
|
||||||
// {"arguments", "this could be anything"},
|
|
||||||
// }}
|
|
||||||
// }});
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "I'm thinking<|python_tag|>",
|
|
||||||
// "I'm thinking",
|
|
||||||
// json {{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"name", "python"},
|
|
||||||
// {"arguments", ""},
|
|
||||||
// }}
|
|
||||||
// }});
|
|
||||||
// auto special_function_call = json {
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"arguments", dump({{"arg1", 1}})},
|
|
||||||
// {"name", "special_function"},
|
|
||||||
// }},
|
|
||||||
// };
|
|
||||||
// auto special_function_call_with_id = json::parse(special_function_call.dump());
|
|
||||||
// special_function_call_with_id["id"] = "123456789";
|
|
||||||
|
|
||||||
// auto no_function_call = json::array();
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
|
|
||||||
// "",
|
|
||||||
// json::array({{
|
|
||||||
// {"type", "function"},
|
|
||||||
// {"function", {
|
|
||||||
// {"arguments", dump({{"code", "print('Hey')"}})},
|
|
||||||
// {"name", "python"},
|
|
||||||
// }}
|
|
||||||
// }}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
|
||||||
// "",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
|
||||||
// "",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
|
||||||
// "",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
|
||||||
// "",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
|
||||||
// "",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
|
|
||||||
// // No match: function unknown
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// no_function_call);
|
|
||||||
// // No match: bad indentation
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// no_function_call);
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
|
||||||
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
|
||||||
// no_function_call);
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
|
|
||||||
// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
|
||||||
// "Bleh",
|
|
||||||
// json::array({special_function_call_with_id}));
|
|
||||||
|
|
||||||
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
|
|
||||||
// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
|
|
||||||
// "Bleh",
|
|
||||||
// json::array({special_function_call}));
|
|
||||||
// }
|
|
||||||
|
|
||||||
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
|
||||||
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
||||||
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
||||||
|
|
||||||
|
@ -325,13 +177,17 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
||||||
params.messages = json::array();
|
params.messages = json::array();
|
||||||
params.messages.push_back(user_message);
|
params.messages.push_back(user_message);
|
||||||
params.tools = tools;
|
params.tools = tools;
|
||||||
std::string prefix = common_chat_init(tmpl, params).prompt;
|
auto prefix_data = common_chat_init(tmpl, params);
|
||||||
params.messages.push_back(delta_message);
|
params.messages.push_back(delta_message);
|
||||||
params.add_generation_prompt = false;
|
params.add_generation_prompt = false;
|
||||||
std::string full = common_chat_init(tmpl, params).prompt;
|
auto full_data = common_chat_init(tmpl, params);
|
||||||
|
|
||||||
|
std::string prefix = prefix_data.prompt;
|
||||||
|
std::string full = full_data.prompt;
|
||||||
|
|
||||||
// Check full starts with prefix
|
// Check full starts with prefix
|
||||||
if (full.find(prefix) != 0) {
|
if (full.find(prefix) != 0) {
|
||||||
|
fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
|
||||||
throw std::runtime_error("Full message does not start with prefix");
|
throw std::runtime_error("Full message does not start with prefix");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,27 +206,12 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return delta;
|
return {delta, full_data.grammar, full_data.parser};
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) {
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
|
||||||
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||||
common_chat_msg expected_msg {
|
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||||
"assistant",
|
|
||||||
"",
|
|
||||||
{},
|
|
||||||
};
|
|
||||||
auto has_tool_calls = test_message.contains("tool_calls");
|
|
||||||
if (has_tool_calls) {
|
|
||||||
for (const auto & tc : test_message.at("tool_calls")) {
|
|
||||||
const auto & arguments = tc.at("function").at("arguments");
|
|
||||||
expected_msg.tool_calls.push_back({
|
|
||||||
tc.at("function").at("name").get<std::string>(),
|
|
||||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
|
||||||
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||||
// get the diff and try and parse it w/ the grammar.
|
// get the diff and try and parse it w/ the grammar.
|
||||||
|
@ -385,35 +226,37 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
params.parallel_tool_calls = true;
|
params.parallel_tool_calls = true;
|
||||||
params.messages = json {user_message, test_message};
|
params.messages = json {user_message, test_message};
|
||||||
params.tools = tools;
|
params.tools = tools;
|
||||||
auto chat_data = common_chat_init(tmpl, params);
|
|
||||||
// fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
|
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||||
if (has_tool_calls) {
|
std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl;
|
||||||
auto grammar = build_grammar(chat_data.grammar);
|
if (!expected_delta.empty()) {
|
||||||
|
assert_equals(expected_delta, data.delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!skip_parser_test) {
|
||||||
|
const auto msg = data.parser(data.delta);
|
||||||
|
assert_msg_equals(expected_msg, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!expected_msg.tool_calls.empty()) {
|
||||||
|
GGML_ASSERT(!data.grammar.empty());
|
||||||
|
}
|
||||||
|
if (!data.grammar.empty()) {
|
||||||
|
auto grammar = build_grammar(data.grammar);
|
||||||
if (!grammar) {
|
if (!grammar) {
|
||||||
throw std::runtime_error("Failed to build grammar");
|
throw std::runtime_error("Failed to build grammar");
|
||||||
}
|
}
|
||||||
|
// TODO: exercice lazy grammars + triggers here, instead of skipping the test
|
||||||
if (!skip_grammar_test) {
|
if (!skip_grammar_test) {
|
||||||
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
|
if (!match_string(data.delta, grammar.get())) {
|
||||||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + "\n\nGrammar: " + data.grammar);
|
||||||
|
|
||||||
const auto msg = chat_data.parser(full_delta);
|
|
||||||
assert_msg_equals(expected_msg, msg);
|
|
||||||
|
|
||||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
|
||||||
{"role", "assistant"},
|
|
||||||
{"content", {}},
|
|
||||||
{"tool_calls", test_message.at("tool_calls")}
|
|
||||||
}, tools);
|
|
||||||
if (!match_string(content_less_delta, grammar.get())) {
|
|
||||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_grammars() {
|
static void test_template_output_parsers() {
|
||||||
auto text_message = json {
|
auto text_message = json {
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
{"content", "Hello, world!"},
|
{"content", "Hello, world!"},
|
||||||
|
@ -465,6 +308,7 @@ static void test_grammars() {
|
||||||
|
|
||||||
common_chat_params tools_params = no_tools_params;
|
common_chat_params tools_params = no_tools_params;
|
||||||
tools_params.tools = json::array();
|
tools_params.tools = json::array();
|
||||||
|
tools_params.tools.push_back(special_function_tool);
|
||||||
|
|
||||||
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
||||||
auto data = common_chat_init(tmpl, params);
|
auto data = common_chat_init(tmpl, params);
|
||||||
|
@ -477,114 +321,182 @@ static void test_grammars() {
|
||||||
|
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||||
|
"{\n"
|
||||||
|
" \"response\": \"Hello, world!\"\n"
|
||||||
|
"}"));
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||||
|
"{\n"
|
||||||
|
" \"tool_calls\": [\n"
|
||||||
|
" {\n"
|
||||||
|
" \"name\": \"special_function\",\n"
|
||||||
|
" \"arguments\": {\n"
|
||||||
|
" \"arg1\": 1\n"
|
||||||
|
" },\n"
|
||||||
|
" \"id\": \"123456789\"\n"
|
||||||
|
" }\n"
|
||||||
|
" ]\n"
|
||||||
|
"}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|end|>" };
|
std::vector<std::string> end_tokens { "<|end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
|
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||||
|
"{\n"
|
||||||
|
" \"response\": \"Hello, world!\"\n"
|
||||||
|
"}"));
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||||
|
"{\n"
|
||||||
|
" \"tool_calls\": [\n"
|
||||||
|
" {\n"
|
||||||
|
" \"name\": \"special_function\",\n"
|
||||||
|
" \"arguments\": {\n"
|
||||||
|
" \"arg1\": 1\n"
|
||||||
|
" },\n"
|
||||||
|
" \"id\": \"123456789\"\n"
|
||||||
|
" }\n"
|
||||||
|
" ]\n"
|
||||||
|
"}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "</s>" };
|
std::vector<std::string> end_tokens { "</s>" };
|
||||||
|
|
||||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||||
|
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||||
|
/* skip_grammar_test= */ true);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
"<tool_call>\n"
|
||||||
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||||
|
"</tool_call>");
|
||||||
|
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||||
|
"<tool_call>\n"
|
||||||
|
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||||
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"<tool_call>\n"
|
||||||
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||||
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"<tool_call>\n"
|
||||||
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||||
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
// test_template(tmpl, end_tokens, text_message, tools);
|
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||||
|
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
||||||
}
|
}
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"<function=special_function>{\"arg1\": 1}</function>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
|
"all\n"
|
||||||
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"special_function\n"
|
||||||
|
"{\"arg1\": 1}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||||
|
"```json\n"
|
||||||
|
"{\"arg1\": 1}\n"
|
||||||
|
"```<|tool▁call▁end|>");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// test_parsing();
|
// test_parsing();
|
||||||
test_grammars();
|
test_template_output_parsers();
|
||||||
|
|
||||||
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue