refactor test-chat-handler
This commit is contained in:
parent
18d5a1b2ca
commit
4a1e8e9f91
1 changed files with 25 additions and 71 deletions
|
@ -319,32 +319,10 @@ static void test_template_output_parsers() {
|
|||
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\",\n"
|
||||
" \"arguments\": {\n"
|
||||
" \"arg1\": 1\n"
|
||||
" },\n"
|
||||
" \"id\": \"123456789\"\n"
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end|>" };
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
"{\n"
|
||||
|
@ -368,16 +346,20 @@ static void test_template_output_parsers() {
|
|||
std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||
/* skip_grammar_test= */ true);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
|
@ -388,35 +370,13 @@ static void test_template_output_parsers() {
|
|||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
|
@ -424,44 +384,36 @@ static void test_template_output_parsers() {
|
|||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"all\n"
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@ -474,6 +426,7 @@ static void test_template_output_parsers() {
|
|||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
@ -484,6 +437,7 @@ static void test_template_output_parsers() {
|
|||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue