diff --git a/common/tool-call.cpp b/common/tool-call.cpp index 4e215a459..ad71fd9e2 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -185,9 +185,9 @@ static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std:: }; } } - static std::regex function_regex("(?:^|\\n)\\{\"name\": \"([^\"]+)\", \"parameters\": "); + static std::regex function_regex("\\{[\\s\\n\\r]*\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); - return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); + return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { @@ -270,7 +270,7 @@ llama_tool_call_handler llama_tool_call_handler_init( tool_rules.push_back( builder.add_rule( name + "-call", - "\"\\n\"? \"{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + "\"\\n\"? \"{\" space \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); if (allow_content && !eagerly_match_any_json) { diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index 4450f9aa9..f7e5e2027 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -228,19 +228,45 @@ static void test_parsing() { {"arguments", dump({{"code", ""}})} }} }}); - test_parse_tool_call(llama_tool_call_style::Llama31, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - "", - json {{ + auto just_special_function_call = json {{ {"type", "function"}, {"function", { {"name", "special_function"}, {"arguments", dump({{"arg1", 1}})} }} - }}); + }}; + auto no_function_call = json::array(); + + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + "", + just_special_function_call); + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + "", + just_special_function_call); + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + "", + just_special_function_call); + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + "", + just_special_function_call); + // No match: function unknown test_parse_tool_call(llama_tool_call_style::Llama31, tools, "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", - "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); + "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + no_function_call); + // No match: bad indentation + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + no_function_call); + test_parse_tool_call(llama_tool_call_style::Llama31, tools, + "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + no_function_call); } static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) { @@ -334,9 +360,9 @@ static void test_grammars() { } int main() { - test_grammars(); - test_parsing(); test_tool_call_style_detection(); + test_parsing(); + test_grammars(); std::cout << "[tool-call] All tests passed!" << std::endl; return 0;