refactor test-chat-handler

This commit is contained in:
ochafik 2025-01-29 04:00:01 +00:00
parent 18d5a1b2ca
commit 4a1e8e9f91

View file

@ -319,32 +319,10 @@ static void test_template_output_parsers() {
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end_of_turn>" };
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
"{\n"
" \"response\": \"Hello, world!\"\n"
"}"));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|end|>" };
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
"{\n"
@ -368,16 +346,20 @@ static void test_template_output_parsers() {
std::vector<std::string> end_tokens { "</s>" };
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
@ -388,35 +370,13 @@ static void test_template_output_parsers() {
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
@ -424,44 +384,36 @@ static void test_template_output_parsers() {
"<|python_tag|>python.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"all\n"
"Hello, world!", /* skip_grammar_test= */ true);
@ -474,6 +426,7 @@ static void test_template_output_parsers() {
std::vector<std::string> end_tokens { "<|eot_id|>" };
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
@ -484,6 +437,7 @@ static void test_template_output_parsers() {
std::vector<std::string> end_tokens { "<end▁of▁sentence>" };
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,