refactor test-chat-handler
This commit is contained in:
parent
18d5a1b2ca
commit
4a1e8e9f91
1 changed files with 25 additions and 71 deletions
|
@ -319,32 +319,10 @@ static void test_template_output_parsers() {
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
|
||||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
|
||||||
"{\n"
|
|
||||||
" \"response\": \"Hello, world!\"\n"
|
|
||||||
"}"));
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
|
||||||
"{\n"
|
|
||||||
" \"tool_calls\": [\n"
|
|
||||||
" {\n"
|
|
||||||
" \"name\": \"special_function\",\n"
|
|
||||||
" \"arguments\": {\n"
|
|
||||||
" \"arg1\": 1\n"
|
|
||||||
" },\n"
|
|
||||||
" \"id\": \"123456789\"\n"
|
|
||||||
" }\n"
|
|
||||||
" ]\n"
|
|
||||||
"}");
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|end|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
|
||||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
|
||||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||||
"{\n"
|
"{\n"
|
||||||
|
@ -368,16 +346,20 @@ static void test_template_output_parsers() {
|
||||||
std::vector<std::string> end_tokens { "</s>" };
|
std::vector<std::string> end_tokens { "</s>" };
|
||||||
|
|
||||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||||
/* skip_grammar_test= */ true);
|
/* skip_grammar_test= */ true);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
"<tool_call>\n"
|
"<tool_call>\n"
|
||||||
|
@ -388,35 +370,13 @@ static void test_template_output_parsers() {
|
||||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||||
"</tool_call>");
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
||||||
"<tool_call>\n"
|
|
||||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
||||||
"</tool_call>");
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
||||||
"<tool_call>\n"
|
|
||||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
||||||
"</tool_call>");
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||||
|
@ -424,44 +384,36 @@ static void test_template_output_parsers() {
|
||||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
"<function=special_function>{\"arg1\": 1}</function>");
|
"<function=special_function>{\"arg1\": 1}</function>");
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
|
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
"all\n"
|
"all\n"
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
|
@ -474,6 +426,7 @@ static void test_template_output_parsers() {
|
||||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
@ -484,6 +437,7 @@ static void test_template_output_parsers() {
|
||||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools,
|
test_template(tmpl, end_tokens, text_message, tools,
|
||||||
"Hello, world!", /* skip_grammar_test= */ true);
|
"Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue