tool-call: fix functionary v3.1 required test
This commit is contained in:
parent
5ec4c5e4d3
commit
f7078cab36
3 changed files with 52 additions and 32 deletions
|
@ -102,6 +102,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
|||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (name == "python" && std::regex_match("", close_regex)) {
|
||||
std::string src(it, end);
|
||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
|
@ -390,11 +395,11 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
|||
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||
common_chat_data data;
|
||||
auto has_python = false;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto has_python = false;
|
||||
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
|
@ -433,7 +438,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
|
|||
}
|
||||
}
|
||||
|
||||
if (has_python) {
|
||||
if (has_python && uses_python_tag) {
|
||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
|
@ -453,8 +458,8 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
|
|||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
|
||||
{"builtin_tools", builtin_tools},
|
||||
});
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
|
||||
if (uses_python_tag) {
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg {
|
||||
if (has_python && uses_python_tag) {
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
|
@ -521,10 +526,10 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
|
|||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_data data;
|
||||
|
||||
auto has_python = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
auto has_python = false;
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
|
@ -544,7 +549,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
|
|||
}
|
||||
}
|
||||
}
|
||||
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
// Note: if there's a python rule, it needs to come last.
|
||||
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
|
||||
if (has_python && params.tool_choice != "required") {
|
||||
|
@ -553,14 +558,14 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
|
|||
}
|
||||
if (params.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
|
||||
builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
|
||||
} else {
|
||||
builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : ""));
|
||||
builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : ""));
|
||||
}
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
|
@ -723,7 +728,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
|
|||
}
|
||||
|
||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
if (params.tools.is_null()) {
|
||||
if (params.tools.is_null() || params.tool_choice == "none") {
|
||||
return common_chat_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
|
@ -3788,11 +3788,14 @@ int main(int argc, char ** argv) {
|
|||
/* .tools = */ json_value(data, "tools", json()),
|
||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||
/* .json_schema = */ json_value(data, "json_schema", json()),
|
||||
/* .parallel_tool_calls = */ json_value(data, "json_schema", true),
|
||||
/* .stream = */ json_value(data, "json_schema", false),
|
||||
/* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
|
||||
/* .stream = */ json_value(data, "stream", false),
|
||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||
});
|
||||
if (data.contains("grammar")) {
|
||||
if (!chat_data.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot provide grammar and tools");
|
||||
}
|
||||
chat_data.grammar = data.at("grammar");
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -226,23 +226,31 @@ CODE_INTEPRETER_TOOL = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [
|
||||
("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"),
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None),
|
||||
("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, None),
|
||||
("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, None),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, None),
|
||||
# # ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, None),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
|
||||
])
|
||||
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str):
|
||||
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
|
||||
n_predict = 512
|
||||
global server
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
|
@ -267,9 +275,13 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
|
|||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, str)
|
||||
if argument_key is not None:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue