Add CLI mode to test-chat to generate a Markdown summary of template formats

ochafik 2025-01-29 22:33:16 +00:00
parent 84bc083faf
commit 2b2456978a

tests/test-chat.cpp

@@ -1,3 +1,13 @@
+/*
+    Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+
+    Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+    e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+
+        cmake -B build && cmake --build build --parallel && \
+            ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+*/
+
 #include "chat.hpp"
 #include "chat-template.hpp"
 #include "llama-grammar.h"
@@ -44,7 +54,7 @@ static void assert_equals(const T & expected, const T & actual) {
 }
 
 static std::string read_file(const std::string &path) {
-    std::cout << "# Reading: " << path << std::endl << std::flush;
+    std::cerr << "# Reading: " << path << std::endl << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -168,13 +178,15 @@ struct delta_data {
     common_chat_parser parser;
 };
 
-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
     common_chat_inputs inputs;
     inputs.parallel_tool_calls = true;
     inputs.messages = json::array();
     inputs.messages.push_back(user_message);
     inputs.tools = tools;
+    inputs.tool_choice = tool_choice;
     auto params_prefix = common_chat_params_init(tmpl, inputs);
+
     inputs.messages.push_back(delta_message);
     inputs.add_generation_prompt = false;
     auto params_full = common_chat_params_init(tmpl, inputs);
@@ -220,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     };
 
     for (const auto & tool_choice : json({"auto", "required"})) {
-        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
+        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
@@ -248,6 +260,10 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     }
 }
 
+static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
+    return common_chat_params_init(tmpl, params).format;
+}
+
 static void test_template_output_parsers() {
     auto text_message = json {
         {"role", "assistant"},
@@ -295,29 +311,25 @@ static void test_template_output_parsers() {
     };
 
-    common_chat_inputs no_tools_params;
-    no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
+    common_chat_inputs inputs_no_tools;
+    inputs_no_tools.messages = {{{"role", "user"}, {"content", "Hey"}}};
 
-    common_chat_inputs tools_params = no_tools_params;
-    tools_params.tools = json::array();
-    tools_params.tools.push_back(special_function_tool);
-
-    auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
-        return common_chat_params_init(tmpl, params).format;
-    };
+    common_chat_inputs inputs_tools = inputs_no_tools;
+    inputs_tools.tools = json::array();
+    inputs_tools.tools.push_back(special_function_tool);
 
     {
         const common_chat_template tmpl(read_file(
             "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<end_of_turn>" };
 
-        assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
-        assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
-            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
 
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
-        assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
+        assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
             "{\n"
             "  \"response\": \"Hello, world!\"\n"
             "}"));
@@ -339,7 +351,7 @@ static void test_template_output_parsers() {
             "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "</s>" };
 
-        assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
@@ -351,11 +363,11 @@ static void test_template_output_parsers() {
             "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|im_end|>" };
 
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
-            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
-            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
+        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -372,9 +384,9 @@ static void test_template_output_parsers() {
             "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
-            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
+            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
 
         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -389,7 +401,7 @@ static void test_template_output_parsers() {
             "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -401,7 +413,7 @@ static void test_template_output_parsers() {
             "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -413,8 +425,8 @@ static void test_template_output_parsers() {
             "models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
-        assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "all\n"
@@ -428,7 +440,7 @@ static void test_template_output_parsers() {
             "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eot_id|>" };
 
-        assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -440,7 +452,7 @@ static void test_template_output_parsers() {
             "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<｜end▁of▁sentence｜>" };
 
-        assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
 
         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -452,9 +464,33 @@ static void test_template_output_parsers() {
     }
 }
 
-int main() {
-    test_template_output_parsers();
+int main(int argc, char **argv) {
+#ifndef _WIN32
+    if (argc > 1) {
+        common_chat_inputs inputs;
+        inputs.messages = {{{"role", "user"}, {"content", "Hey"}}};
+        inputs.tools = json::array({special_function_tool});
-    std::cout << "\n[tool-call] All tests passed!" << std::endl;
+
+        std::cout << "| Template | Format |\n";
+        std::cout << "|----------|--------|\n";
+        for (int i = 1; i < argc; i++) {
+            std::string path = argv[i];
+            if (path.rfind(".jinja") != path.size() - 6) {
+                std::cerr << "Skipping non-jinja file: " << path << std::endl;
+                continue;
+            }
+            common_chat_template tmpl(read_file(path), "", "");
+            auto parts = string_split(path, "/");
+            auto name = parts[parts.size() - 1];
+            std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
+        }
+    }
+    else
+#endif
+    {
+        test_template_output_parsers();
+        std::cout << "\n[chat] All tests passed!" << std::endl;
+    }
 
     return 0;
 }
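
A note on the init_delta change above: the harness renders the chat prompt twice, once without the assistant delta_message (params_prefix) and once with it (params_full), and the tests then work on only the text the template appended for that one message. Below is a minimal standalone sketch of that prefix-stripping idea; the render() helper here is a toy stand-in for the two common_chat_params_init() calls, and the real harness presumably also trims the end_tokens it receives.

    #include <cassert>
    #include <iostream>
    #include <string>

    // Toy stand-in for the two prompt renders in init_delta:
    // without the delta message (prefix) and with it (full).
    static std::string render(bool with_delta) {
        std::string prompt = "<s>user: Hey</s>\n";
        if (with_delta) {
            prompt += "assistant: Hello, world!</s>";
        }
        return prompt;
    }

    int main() {
        const std::string prefix = render(false);
        const std::string full   = render(true);
        // The delta is whatever the template appended for the extra message:
        // check that full starts with prefix, then strip it off.
        assert(full.compare(0, prefix.size(), prefix) == 0);
        std::cout << full.substr(prefix.size()) << std::endl;  // assistant: Hello, world!</s>
        return 0;
    }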
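
For a concrete picture of the new CLI mode: the assertions above fix the format reported for each bundled template when tools are supplied, so pointing the binary at models/templates/*.jinja should print a Markdown table along these lines. The rows below are reconstructed from the test expectations (a subset), not captured from an actual run:

    ./build/bin/test-chat models/templates/*.jinja 2>/dev/null

    | Template | Format |
    |----------|--------|
    | google-gemma-2-2b-it.jinja | generic tool calls |
    | mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls |
    | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls |
    | meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) |
    | meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls |
    | fireworks-ai-llama-3-firefunction-v2.jinja | firefunction v2 tool calls |
    | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls |

The 2>/dev/null redirect matters because read_file now logs "# Reading: ..." to stderr (the cout to cerr change above), keeping stdout clean Markdown.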