Rehabilitate test_format_detection

This commit is contained in:
ochafik 2025-01-27 20:46:03 +00:00
parent add9124115
commit fa065eb095
4 changed files with 45 additions and 28 deletions

View file

@ -288,6 +288,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "generic tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
@ -355,7 +356,8 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
data.format = "mistral nemo tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
});
return data;
@ -397,6 +399,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
{"builtin_tools", builtin_tools},
});
data.format = "llama 3.1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
@ -452,7 +455,8 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
data.format = "llama 3.2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
@ -482,7 +486,8 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
data.format = "deepseek r1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
@ -524,13 +529,14 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
data.format = "firefunction v2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
});
return data;
}
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
@ -562,7 +568,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
data.format = "functionary v3.2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
@ -629,6 +636,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "functionary v3.1 llama 3.1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
@ -682,6 +690,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "hermes 2 pro tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
@ -733,6 +742,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only";
data.parser = std::make_unique<text_chat_parser>();
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
@ -759,7 +769,7 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
}
if (src.find(">>>all") != std::string::npos) {
return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params);
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
}
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {

View file

@ -41,6 +41,7 @@ struct common_chat_data {
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> additional_stops;
std::unique_ptr<class common_chat_parser> parser;
std::string format; // For debugging and testing.
};
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

View file

@ -208,7 +208,6 @@ class chat_template {
arguments = json::parse(arguments.get<std::string>());
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
arguments = arguments;
}
}
}

View file

@ -298,26 +298,33 @@ const json tools = {special_function_tool, python_tool};
// json::array({special_function_call}));
// }
// static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) {
// const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
// auto tool_call_style = common_tool_call_style_detect(tmpl);
// std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
// assert_equals(expected, tool_call_style);
// }
// Checks that common_chat_init() detects the expected chat format
// (exposed as common_chat_data::format, added by this commit for debugging
// and testing) for each known chat template file, both with and without tools.
static void test_format_detection() {
// Minimal request: a single user message and no tools.
common_chat_params no_tools_params;
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
// static void test_tool_call_style_detection() {
// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1);
// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3);
// test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2);
// test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1);
// test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2);
// test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
// test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS);
// test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO);
// test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC);
// }
// Same request, but with an (empty) tools array so tool-call handling kicks in.
common_chat_params tools_params = no_tools_params;
tools_params.tools = json::array();
// Loads the template file, runs format detection via common_chat_init(),
// and returns the detected format name for comparison.
auto describe = [](const std::string & template_file, const common_chat_params & params) {
const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
auto data = common_chat_init(tmpl, params);
return data.format;
};
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params));
assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params));
assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params));
assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params));
assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params));
assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params));
assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params));
// Gemma has no tool-call syntax, so tools fall back to the generic JSON format...
assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params));
// ...and without tools it is plain content.
assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params));
// TODO(review): re-enable once Command R Plus is routed in common_chat_init.
// assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", tools_params));
}
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
@ -498,7 +505,7 @@ static void test_grammars() {
}
int main() {
// test_tool_call_style_detection();
test_format_detection();
// test_parsing();
test_grammars();