Update test-chat-handler.cpp
This commit is contained in:
parent
b565ab2ab1
commit
2d607f1a68
1 changed files with 152 additions and 108 deletions
|
@ -298,34 +298,6 @@ const json tools = {special_function_tool, python_tool};
|
||||||
// json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
// }
|
// }
|
||||||
|
|
||||||
static void test_format_detection() {
|
|
||||||
common_chat_params no_tools_params;
|
|
||||||
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
|
||||||
|
|
||||||
common_chat_params tools_params = no_tools_params;
|
|
||||||
tools_params.tools = json::array();
|
|
||||||
|
|
||||||
auto describe = [](const std::string & template_file, const common_chat_params & params) {
|
|
||||||
const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
|
|
||||||
auto data = common_chat_init(tmpl, params);
|
|
||||||
return data.format;
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params));
|
|
||||||
assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params));
|
|
||||||
assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params));
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params));
|
|
||||||
assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params));
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params));
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params));
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params));
|
|
||||||
assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params));
|
|
||||||
assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params));
|
|
||||||
assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params));
|
|
||||||
assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params));
|
|
||||||
// assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", tools_params));
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
||||||
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
||||||
|
@ -363,14 +335,16 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
||||||
return delta;
|
return delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) {
|
||||||
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||||
common_chat_msg expected_msg {
|
common_chat_msg expected_msg {
|
||||||
"assistant",
|
"assistant",
|
||||||
"",
|
"",
|
||||||
{},
|
{},
|
||||||
};
|
};
|
||||||
for (const auto & tc : tool_calling_message.at("tool_calls")) {
|
auto has_tool_calls = test_message.contains("tool_calls");
|
||||||
|
if (has_tool_calls) {
|
||||||
|
for (const auto & tc : test_message.at("tool_calls")) {
|
||||||
const auto & arguments = tc.at("function").at("arguments");
|
const auto & arguments = tc.at("function").at("arguments");
|
||||||
expected_msg.tool_calls.push_back({
|
expected_msg.tool_calls.push_back({
|
||||||
tc.at("function").at("name").get<std::string>(),
|
tc.at("function").at("name").get<std::string>(),
|
||||||
|
@ -378,6 +352,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||||
// get the diff and try and parse it w/ the grammar.
|
// get the diff and try and parse it w/ the grammar.
|
||||||
|
@ -386,19 +361,22 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
{"content", "Hello, world!"}
|
{"content", "Hello, world!"}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||||
common_chat_params params;
|
common_chat_params params;
|
||||||
|
params.tool_choice = tool_choice;
|
||||||
params.parallel_tool_calls = true;
|
params.parallel_tool_calls = true;
|
||||||
params.messages = json {user_message, tool_calling_message};
|
params.messages = json {user_message, test_message};
|
||||||
params.tools = tools;
|
params.tools = tools;
|
||||||
auto chat_data = common_chat_init(tmpl, params);
|
auto chat_data = common_chat_init(tmpl, params);
|
||||||
fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
|
// fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
|
||||||
|
if (has_tool_calls) {
|
||||||
auto grammar = build_grammar(chat_data.grammar);
|
auto grammar = build_grammar(chat_data.grammar);
|
||||||
if (!grammar) {
|
if (!grammar) {
|
||||||
throw std::runtime_error("Failed to build grammar");
|
throw std::runtime_error("Failed to build grammar");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!skip_grammar_test) {
|
if (!skip_grammar_test) {
|
||||||
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||||
|
|
||||||
const auto msg = chat_data.parser->parse_final(full_delta);
|
const auto msg = chat_data.parser->parse_final(full_delta);
|
||||||
|
@ -407,15 +385,21 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
{"content", {}},
|
{"content", {}},
|
||||||
{"tool_calls", tool_calling_message.at("tool_calls")}
|
{"tool_calls", test_message.at("tool_calls")}
|
||||||
}, tools);
|
}, tools);
|
||||||
if (!match_string(content_less_delta, grammar.get())) {
|
if (!match_string(content_less_delta, grammar.get())) {
|
||||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void test_grammars() {
|
static void test_grammars() {
|
||||||
|
auto text_message = json {
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", "Hello, world!"},
|
||||||
|
};
|
||||||
auto tool_call_message = json {
|
auto tool_call_message = json {
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
{"content", {}},
|
{"content", {}},
|
||||||
|
@ -444,68 +428,128 @@ static void test_grammars() {
|
||||||
}}}
|
}}}
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
common_chat_params no_tools_params;
|
||||||
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||||
}
|
|
||||||
{
|
common_chat_params tools_params = no_tools_params;
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
tools_params.tools = json::array();
|
||||||
// assert_equals(tmpl.requires_object_arguments_, true);
|
|
||||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
|
||||||
test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
|
auto data = common_chat_init(tmpl, params);
|
||||||
}
|
return data.format;
|
||||||
{
|
};
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
|
||||||
test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||||
test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
|
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
||||||
test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
|
std::vector<std::string> end_tokens { "<|end|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "</s>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||||
test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools);
|
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_format_detection();
|
|
||||||
// test_parsing();
|
// test_parsing();
|
||||||
test_grammars();
|
test_grammars();
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue