tool-call: fix llama 3.x tc parsing when there are spaces before "name"

This commit is contained in:
Olivier Chafik 2024-10-03 21:46:41 +01:00
parent da02397f7f
commit 366efc8a18
2 changed files with 37 additions and 11 deletions

View file

@ -185,9 +185,9 @@ static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::
};
}
}
static std::regex function_regex("(?:^|\\n)\\{\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex function_regex("\\{[\\s\\n\\r]*\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
@ -270,7 +270,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
"\"\\n\"? \"{\" space \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content && !eagerly_match_any_json) {

View file

@ -228,19 +228,45 @@ static void test_parsing() {
{"arguments", dump({{"code", ""}})}
}}
}});
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
json {{
auto just_special_function_call = json {{
{"type", "function"},
{"function", {
{"name", "special_function"},
{"arguments", dump({{"arg1", 1}})}
}}
}});
}};
auto no_function_call = json::array();
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
just_special_function_call);
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
just_special_function_call);
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
just_special_function_call);
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
just_special_function_call);
// No match: function unknown
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call);
// No match: bad indentation
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call);
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call);
}
static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
@ -334,9 +360,9 @@ static void test_grammars() {
}
int main() {
test_grammars();
test_parsing();
test_tool_call_style_detection();
test_parsing();
test_grammars();
std::cout << "[tool-call] All tests passed!" << std::endl;
return 0;