Rehabilitate test_format_detection
parent add9124115
commit fa065eb095
4 changed files with 45 additions and 28 deletions
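The change is mechanical across the tool-call handlers below: each common_chat_init_* function now stamps a human-readable name into data.format, and the parser lambdas drop their explicit -> common_chat_msg return type where it can be deduced. A minimal sketch of the resulting per-handler shape (the "example" handler is hypothetical; the types and helpers are the ones declared in this diff):

// Hypothetical handler illustrating the pattern this commit applies to each
// common_chat_init_* function; types and helpers as declared in the diff below.
static common_chat_data common_chat_init_example_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;
    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
    data.format = "example tool calls"; // new: names the handler for tests and debugging
    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
        // return type now deduced (previously spelled out as -> common_chat_msg)
        return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
    });
    return data;
}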
@@ -288,6 +288,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
     data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.format = "generic tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;

@@ -355,7 +356,8 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
         data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.format = "mistral nemo tool calls";
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
     });
     return data;

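For context, the Mistral Nemo parser above expects generations that start with the [TOOL_CALLS] trigger followed by a JSON array of calls, shaped roughly like this (illustrative payload, not taken from the diff; the tool name matches the test fixtures further down):

[TOOL_CALLS][{"name": "special_function", "arguments": {"arg1": 1}, "id": "123456789"}]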
@@ -397,6 +399,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
         {"builtin_tools", builtin_tools},
     });
+    data.format = "llama 3.1 tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");

@@ -452,7 +455,8 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.format = "llama 3.2 tool calls";
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
         auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);

@@ -482,7 +486,8 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
         builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
     }, grammar_options);
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.format = "deepseek r1 tool calls";
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
         static std::regex trigger_regex("<|tool▁calls▁begin|>");
         static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
         static std::regex close_regex("```<|tool▁call▁end|>");

@@ -524,13 +529,14 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
         data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.format = "firefunction v2 tool calls";
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
     });
     return data;
 }
 
-static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar

@@ -562,7 +568,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     }, grammar_options);
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.format = "functionary v3.2 tool calls";
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
 
@@ -629,6 +636,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     }, grammar_options);
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.format = "functionary v3.1 llama 3.1 tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");

@@ -682,6 +690,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     }, grammar_options);
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.format = "hermes 2 pro tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");

@@ -733,6 +742,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.format = "content-only";
     data.parser = std::make_unique<text_chat_parser>();
     if (!params.json_schema.is_null()) {
         if (!params.grammar.empty()) {

@@ -759,7 +769,7 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
         return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
     }
     if (src.find(">>>all") != std::string::npos) {
-        return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params);
+        return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {

@@ -41,6 +41,7 @@ struct common_chat_data {
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> additional_stops;
     std::unique_ptr<class common_chat_parser> parser;
+    std::string format; // For debugging and testing.
 };
 
 struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

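With the format field in the struct, a caller can check which handler a template was routed to. A minimal sketch (hypothetical caller, mirroring the describe helper in the test changes further down; read_file is the test suite's helper):

// Hypothetical caller: print the detected chat format for a template.
common_chat_params params;
params.messages = {{{"role", "user"}, {"content", "Hey"}}};
params.tools = json::array();
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-7b-it.jinja"), "<s>", "</s>");
auto data = common_chat_init(tmpl, params);
fprintf(stderr, "format: %s\n", data.format.c_str()); // expected: "generic tool calls"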
@@ -208,7 +208,6 @@ class chat_template {
                 arguments = json::parse(arguments.get<std::string>());
             } catch (const std::exception & ecvt) {
                 fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
-                arguments = arguments;
             }
         }
     }

@@ -298,26 +298,33 @@ const json tools = {special_function_tool, python_tool};
 //         json::array({special_function_call}));
 // }
 
-// static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) {
-//     const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
-//     auto tool_call_style = common_tool_call_style_detect(tmpl);
-//     std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
-//     assert_equals(expected, tool_call_style);
-// }
+static void test_format_detection() {
+    common_chat_params no_tools_params;
+    no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
 
-// static void test_tool_call_style_detection() {
-//     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1);
-//     test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3);
-//     test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2);
-//     test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1);
-//     test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2);
-//     test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
-//     test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
-//     test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
-//     test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS);
-//     test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO);
-//     test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC);
-// }
+    common_chat_params tools_params = no_tools_params;
+    tools_params.tools = json::array();
+
+    auto describe = [](const std::string & template_file, const common_chat_params & params) {
+        const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
+        auto data = common_chat_init(tmpl, params);
+        return data.format;
+    };
+
+    assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params));
+    assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params));
+    assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params));
+    assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params));
+    assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params));
+    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params));
+    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params));
+    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params));
+    assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params));
+    assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params));
+    assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params));
+    assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params));
+    // assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", tools_params));
+}
 
 static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());

@@ -498,7 +505,7 @@ static void test_grammars() {
 }
 
 int main() {
-    // test_tool_call_style_detection();
+    test_format_detection();
     // test_parsing();
     test_grammars();
 