diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index effeeefda..abbabe069 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -102,6 +102,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri json arguments; if (!parse_json(it, end, arguments)) { + if (name == "python" && std::regex_match("", close_regex)) { + std::string src(it, end); + result.tool_calls.push_back({name, src, /* id= */ ""}); + break; + } throw std::runtime_error("Failed to parse json tool call arguments"); } if (!std::regex_search(it, end, match, close_regex)) { @@ -390,11 +395,11 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { auto builtin_tools = json {"wolfram_alpha", "brave_search"}; common_chat_data data; - + auto has_python = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - auto has_python = false; for (const auto & tool : params.tools) { if (!tool.contains("type")) { @@ -433,7 +438,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te } } - if (has_python) { + if (has_python && uses_python_tag) { tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); @@ -453,8 +458,8 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, { {"builtin_tools", builtin_tools}, }); - data.parser = std::make_unique([params, uses_python_tag](const std::string & input) -> common_chat_msg { - if (uses_python_tag) { + data.parser = std::make_unique([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg { + if (has_python && uses_python_tag) { static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { @@ -521,10 +526,10 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_data data; + auto has_python = false; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; - auto has_python = false; for (const auto & tool : params.tools) { if (!tool.contains("type")) { continue; @@ -544,7 +549,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const } } } - auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; // Note: if there's a python rule, it needs to come last. auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*"); if (has_python && params.tool_choice != "required") { @@ -553,14 +558,14 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const } if (params.parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); + builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); } else { - builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : "")); + builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : "")); } }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([params, has_python](const std::string & input) -> common_chat_msg { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); @@ -723,7 +728,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat } common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) { - if (params.tools.is_null()) { + if (params.tools.is_null() || params.tool_choice == "none") { return common_chat_init_without_tools(tmpl, params); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 925f4f8ef..a3f99ac26 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3788,11 +3788,14 @@ int main(int argc, char ** argv) { /* .tools = */ json_value(data, "tools", json()), /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")), /* .json_schema = */ json_value(data, "json_schema", json()), - /* .parallel_tool_calls = */ json_value(data, "json_schema", true), - /* .stream = */ json_value(data, "json_schema", false), + /* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false), + /* .stream = */ json_value(data, "stream", false), /* .grammar = */ json_value(data, "grammar", std::string("")), }); if (data.contains("grammar")) { + if (!chat_data.grammar.empty()) { + throw std::runtime_error("Cannot provide grammar and tools"); + } chat_data.grammar = data.at("grammar"); } } else { diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 399d8b937..3e7e3233d 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -226,23 +226,31 @@ CODE_INTEPRETER_TOOL = { } -@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [ - ("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"), - ("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"), - ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"), - ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"), - ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"), - ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"), - ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"), - ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"), - ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"), - ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"), +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None), + ("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None), + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, None), + ("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, None), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None), + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, None), + # # ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None), + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, None), + ("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None), ]) -def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str): +def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None): + n_predict = 512 global server # server = ServerPreset.stories15m_moe() server.jinja = True @@ -267,9 +275,13 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool: tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert tool["function"]["name"] == tool_call["function"]["name"] - actual_arguments = json.loads(tool_call["function"]["arguments"]) - assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [