tool-call: fix functionary v3.1 required test

This commit is contained in:
ochafik 2025-01-26 23:23:09 +00:00
parent 5ec4c5e4d3
commit f7078cab36
3 changed files with 52 additions and 32 deletions

View file

@ -102,6 +102,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
json arguments; json arguments;
if (!parse_json(it, end, arguments)) { if (!parse_json(it, end, arguments)) {
if (name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""});
break;
}
throw std::runtime_error("Failed to parse json tool call arguments"); throw std::runtime_error("Failed to parse json tool call arguments");
} }
if (!std::regex_search(it, end, match, close_regex)) { if (!std::regex_search(it, end, match, close_regex)) {
@ -390,11 +395,11 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data; common_chat_data data;
auto has_python = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
auto has_python = false;
for (const auto & tool : params.tools) { for (const auto & tool : params.tools) {
if (!tool.contains("type")) { if (!tool.contains("type")) {
@ -433,7 +438,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
} }
} }
if (has_python) { if (has_python && uses_python_tag) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") { if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
@ -453,8 +458,8 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, { data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools}, {"builtin_tools", builtin_tools},
}); });
data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg {
if (uses_python_tag) { if (has_python && uses_python_tag) {
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
@ -521,10 +526,10 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data; common_chat_data data;
auto has_python = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules; std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules; std::vector<std::string> subsequent_tool_rules;
auto has_python = false;
for (const auto & tool : params.tools) { for (const auto & tool : params.tools) {
if (!tool.contains("type")) { if (!tool.contains("type")) {
continue; continue;
@ -544,7 +549,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
} }
} }
} }
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
// Note: if there's a python rule, it needs to come last. // Note: if there's a python rule, it needs to come last.
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*"); auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
if (has_python && params.tool_choice != "required") { if (has_python && params.tool_choice != "required") {
@ -553,14 +558,14 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
} }
if (params.parallel_tool_calls) { if (params.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
} else { } else {
builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : "")); builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : ""));
} }
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))"); static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
@ -723,7 +728,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
} }
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) { common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
if (params.tools.is_null()) { if (params.tools.is_null() || params.tool_choice == "none") {
return common_chat_init_without_tools(tmpl, params); return common_chat_init_without_tools(tmpl, params);
} }

View file

@ -3788,11 +3788,14 @@ int main(int argc, char ** argv) {
/* .tools = */ json_value(data, "tools", json()), /* .tools = */ json_value(data, "tools", json()),
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")), /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
/* .json_schema = */ json_value(data, "json_schema", json()), /* .json_schema = */ json_value(data, "json_schema", json()),
/* .parallel_tool_calls = */ json_value(data, "json_schema", true), /* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
/* .stream = */ json_value(data, "json_schema", false), /* .stream = */ json_value(data, "stream", false),
/* .grammar = */ json_value(data, "grammar", std::string("")), /* .grammar = */ json_value(data, "grammar", std::string("")),
}); });
if (data.contains("grammar")) { if (data.contains("grammar")) {
if (!chat_data.grammar.empty()) {
throw std::runtime_error("Cannot provide grammar and tools");
}
chat_data.grammar = data.at("grammar"); chat_data.grammar = data.at("grammar");
} }
} else { } else {

View file

@ -226,23 +226,31 @@ CODE_INTEPRETER_TOOL = {
} }
@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [ @pytest.mark.parametrize("template_name,tool,argument_key", [
("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None),
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, None),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, None),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"), ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"), ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None),
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, None),
# # ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, None),
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
]) ])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str): def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
n_predict = 512
global server global server
# server = ServerPreset.stories15m_moe() # server = ServerPreset.stories15m_moe()
server.jinja = True server.jinja = True
@ -267,9 +275,13 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
assert tool["function"]["name"] == tool_call["function"]["name"] expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"]) assert expected_function_name == tool_call["function"]["name"]
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" actual_arguments = tool_call["function"]["arguments"]
assert isinstance(actual_arguments, str)
if argument_key is not None:
actual_arguments = json.loads(actual_arguments)
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [