Add CLI mode to test-chat to generate a Markdown summary of template formats
This commit is contained in:
parent
84bc083faf
commit
2b2456978a
1 changed files with 71 additions and 35 deletions
|
@ -1,3 +1,13 @@
|
|||
/*
|
||||
Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
|
||||
|
||||
Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
|
||||
    e.g. given Minja (https://github.com/google/minja) checked out in parent dir:
|
||||
|
||||
cmake -B build && cmake --build build --parallel && \
|
||||
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
|
||||
*/
|
||||
#include "chat.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "llama-grammar.h"
|
||||
|
@ -44,7 +54,7 @@ static void assert_equals(const T & expected, const T & actual) {
|
|||
}
|
||||
|
||||
static std::string read_file(const std::string &path) {
|
||||
std::cout << "# Reading: " << path << std::endl << std::flush;
|
||||
std::cerr << "# Reading: " << path << std::endl << std::flush;
|
||||
std::ifstream fs(path, std::ios_base::binary);
|
||||
if (!fs.is_open()) {
|
||||
fs = std::ifstream("../" + path, std::ios_base::binary);
|
||||
|
@ -168,13 +178,15 @@ struct delta_data {
|
|||
common_chat_parser parser;
|
||||
};
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
|
||||
common_chat_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages = json::array();
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
inputs.messages.push_back(delta_message);
|
||||
inputs.add_generation_prompt = false;
|
||||
auto params_full = common_chat_params_init(tmpl, inputs);
|
||||
|
@ -220,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
};
|
||||
|
||||
for (const auto & tool_choice : json({"auto", "required"})) {
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
|
@ -248,6 +260,10 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
}
|
||||
}
|
||||
|
||||
static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||
return common_chat_params_init(tmpl, params).format;
|
||||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
auto text_message = json {
|
||||
{"role", "assistant"},
|
||||
|
@ -295,29 +311,25 @@ static void test_template_output_parsers() {
|
|||
};
|
||||
|
||||
|
||||
common_chat_inputs no_tools_params;
|
||||
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||
common_chat_inputs inputs_no_tools;
|
||||
inputs_no_tools.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||
|
||||
common_chat_inputs tools_params = no_tools_params;
|
||||
tools_params.tools = json::array();
|
||||
tools_params.tools.push_back(special_function_tool);
|
||||
|
||||
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||
return common_chat_params_init(tmpl, params).format;
|
||||
};
|
||||
common_chat_inputs inputs_tools = inputs_no_tools;
|
||||
inputs_tools.tools = json::array();
|
||||
inputs_tools.tools.push_back(special_function_tool);
|
||||
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
||||
|
@ -339,7 +351,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
|
@ -351,11 +363,11 @@ static void test_template_output_parsers() {
|
|||
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
@ -372,9 +384,9 @@ static void test_template_output_parsers() {
|
|||
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
|
||||
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
|
@ -389,7 +401,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@ -401,7 +413,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@ -413,8 +425,8 @@ static void test_template_output_parsers() {
|
|||
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"all\n"
|
||||
|
@ -428,7 +440,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@ -440,7 +452,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@ -452,9 +464,33 @@ static void test_template_output_parsers() {
|
|||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_template_output_parsers();
|
||||
int main(int argc, char **argv) {
|
||||
#ifndef _WIN32
|
||||
if (argc > 1) {
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = {{{"role", "user"}, {"content", "Hey"}}};
|
||||
inputs.tools = json::array({special_function_tool});
|
||||
|
||||
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
||||
std::cout << "| Template | Format |\n";
|
||||
std::cout << "|----------|--------|\n";
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string path = argv[i];
|
||||
if (path.rfind(".jinja") != path.size() - 6) {
|
||||
std::cerr << "Skipping non-jinja file: " << path << std::endl;
|
||||
continue;
|
||||
}
|
||||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
test_template_output_parsers();
|
||||
std::cout << "\n[chat] All tests passed!" << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue