tool-call: remove nonsensical code_interpreter code

This commit is contained in:
ochafik 2025-01-27 14:19:20 +00:00
parent bddc1bebcc
commit da606d8d41
4 changed files with 165 additions and 240 deletions

View file

@ -1,6 +1,7 @@
#include "chat-handler.hpp" #include "chat-handler.hpp"
#include "chat-template.hpp" #include "chat-template.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "log.h"
#include "minja.hpp" #include "minja.hpp"
const common_grammar_options grammar_options { const common_grammar_options grammar_options {
@ -58,7 +59,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content. * Aggregates the prefix, suffix and in-between text into the content.
*/ */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) { static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
std::smatch match; std::smatch match;
common_chat_msg result; common_chat_msg result;
@ -69,15 +70,10 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
std::vector<std::string> tool_names; std::vector<std::string> tool_names;
if (check_names) { if (check_names) {
for (const auto & tool : tools) { for (const auto & tool : tools) {
if (!tool.contains("type")) { if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
continue; continue;
} }
std::string type = tool.at("type"); tool_names.push_back(tool["function"]["name"]);
if (type == "function") {
tool_names.push_back(tool["function"]["name"]);
} else if (type == "code_interpreter") {
tool_names.push_back("python");
}
} }
} }
@ -102,7 +98,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
json arguments; json arguments;
if (!parse_json(it, end, arguments)) { if (!parse_json(it, end, arguments)) {
if (has_python && name == "python" && std::regex_match("", close_regex)) { if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end); std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""}); result.tool_calls.push_back({name, src, /* id= */ ""});
break; break;
@ -207,42 +203,22 @@ public:
} }
}; };
const auto python_tool = json::parse(R"({ static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
"type": "function",
"function": {
"name": "python",
"description": "an ipython interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute."
}
},
"required": ["code"]
}
}
})");
static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) { for (const auto & tool : tools) {
if (!tool.contains("type")) { if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
continue; continue;
} }
if (tool["type"] == "code_interpreter") { fn(tool);
fn(python_tool);
} else if (tool["type"] == "function" && tool.contains("function")) {
fn(tool);
}
} }
} }
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
auto tool_call_schemas = json::array(); auto tool_call_schemas = json::array();
foreach_normalized_tool(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"]; const auto & function = tool["function"];
auto tool_schema = json { auto tool_schema = json {
{"type", "object"}, {"type", "object"},
@ -348,10 +324,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
} }
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_normalized_tool(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"]; const auto & function = tool["function"];
schemas.push_back({ schemas.push_back({
{"type", "object"}, {"type", "object"},
@ -392,108 +369,91 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
return data; return data;
} }
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data; common_chat_data data;
auto has_python = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
auto add_tool = [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"]; const auto & function = tool["function"];
std::string name = function["name"]; std::string name = function["name"];
auto parameters = function["parameters"]; auto parameters = function["parameters"];
builder.resolve_refs(parameters); builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) { tool_rules.push_back(
has_python = true; builder.add_rule(
} else { name + "-call",
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + "\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
tool_rules.push_back( builder.add_schema(name + "-args", parameters) +
builder.add_rule( " \"}\""));
name + "-call", });
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + if (params.tool_choice != "required") {
builder.add_schema(name + "-args", parameters) + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
" \"}\""));
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that c++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
};
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
builtin_tools.push_back("code_interpreter");
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
add_tool(tool);
}
} }
if (has_python) {
if (uses_python_tag) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
} else {
add_tool(python_tool);
}
}
if (params.tool_choice != "required" && eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
}
builder.add_rule("root", string_join(tool_rules, " | ")); builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options); }, grammar_options);
data.additional_stops.push_back("<|eom_id|>"); data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, { data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools}, {"builtin_tools", builtin_tools},
}); });
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
if (has_python && uses_python_tag) { static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ match[1].str(),
/* .id = */ "",
},
}
};
}
}
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}"); static std::regex close_regex("\\}");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag); auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
return res;
}); });
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data;
}
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
// auto add_tool = [&](const json & tool) {
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"{\" "
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
});
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
return res;
});
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data; return data;
} }
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_normalized_tool(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"]; const auto & function = tool["function"];
schemas.push_back({ schemas.push_back({
{"type", "object"}, {"type", "object"},
@ -528,85 +488,92 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
} }
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data; common_chat_data data;
auto has_python = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules; std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules; std::vector<std::string> subsequent_tool_rules;
for (const auto & tool : params.tools) { foreach_function(params.tools, [&](const json & tool) {
if (!tool.contains("type")) { const auto & function = tool["function"];
continue; std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
} }
if (tool["type"] == "code_interpreter") { });
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
}
}
}
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
// Note: if there's a python rule, it needs to come last.
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
if (has_python && params.tool_choice != "required") {
data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
}
if (params.parallel_tool_calls) { if (params.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else { } else {
builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : "")); builder.add_rule("root", first_rule);
} }
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))"); static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
if (res.content.find("all\n") == 0) {
res.content = res.content.substr(4);
}
return res;
}); });
return data; return data;
} }
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
common_chat_data data; common_chat_data data;
json tools = params.tools.is_null() ? params.tools : json::array(); json tools = params.tools.is_null() ? params.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
auto has_python = false; foreach_function(params.tools, [&](const json & tool) {
for (const auto & tool : params.tools) { const auto & function = tool["function"];
if (!tool.contains("type")) { const auto & parameters = function["parameters"];
continue; std::string name = function["name"];
} if (name == "python" || name == "ipython") {
if (tool["type"] == "code_interpreter") { if (!parameters.contains("type")) {
has_python = true; throw std::runtime_error("Missing type in python tool");
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
if (name == "python" || name == "ipython") {
has_python = true;
} else {
auto parameters = function["parameters"];
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
} }
has_raw_python = true;
auto type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
if (it.value().at("type") == "string") {
if (!python_code_argument_name.empty()) {
throw std::runtime_error("Multiple string arguments found in python tool");
}
python_code_argument_name = it.key();
}
}
if (python_code_argument_name.empty()) {
throw std::runtime_error("No string argument found in python tool");
}
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
} }
} });
if (has_python) { if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") { if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
@ -620,18 +587,19 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool. // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
return { return {
/* .role = */ "assistant", /* .role = */ "assistant",
/* .content = */ match.prefix().str(), /* .content = */ match.prefix().str(),
/* .tool_calls = */ { /* .tool_calls = */ {
{ {
/* .name = */ "python", /* .name = */ "python",
/* .arguments = */ match[1].str(), /* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
/* .id = */ "", /* .id = */ "",
}, },
} }
@ -639,17 +607,18 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
} }
static std::regex function_regex(R"(<function=(\w+)>)"); static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)"); static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
}); });
return data; return data;
} }
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_normalized_tool(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"]; const auto & function = tool["function"];
std::string name = function["name"]; std::string name = function["name"];
auto parameters = function["parameters"]; auto parameters = function["parameters"];
@ -719,6 +688,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
} }
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.parser = std::make_unique<text_chat_parser>(); data.parser = std::make_unique<text_chat_parser>();
@ -756,13 +726,11 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos; auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, if (uses_python_tag) {
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
// as it seems to be outputting some JSON. } else {
// TODO: make this conditional on a very small model (e.g. 1B / 3B). return common_chat_init_llama_3_2_tool_calls(tmpl, params);
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2; }
return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
} }
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
// TODO: Command-R-Plus // TODO: Command-R-Plus

View file

@ -143,28 +143,10 @@ class chat_template {
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) { if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
actual_tools = json::array(); actual_tools = json::array();
for (const auto & tool : tools) { for (const auto & tool : tools) {
if (tool.contains("type") && tool.at("type") == "code_interpreter") { if (tool.contains("type") && tool.at("type") == "code_interpreter" && !supports_code_interpreter_) {
static const auto python_tool = json::parse(R"({ continue;
"type": "function",
"function": {
"name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter."
}
},
"required": ["code"]
}
}
})");
actual_tools.push_back(python_tool);
} else {
actual_tools.push_back(tool);
} }
actual_tools.push_back(tool);
} }
} else if (!tools.is_null()) { } else if (!tools.is_null()) {
actual_tools = tools; actual_tools = tools;
@ -295,10 +277,6 @@ class chat_template {
if (!tools.is_null()) { if (!tools.is_null()) {
auto tools_val = minja::Value(actual_tools); auto tools_val = minja::Value(actual_tools);
context->set("tools", tools_val); context->set("tools", tools_val);
if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
context->set("builtin_tools", builtin_tools_val);
}
} }
if (!extra_context.is_null()) { if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) { for (auto & kv : extra_context.items()) {

View file

@ -333,7 +333,7 @@ struct server_task {
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
} }
if (data.contains("json_schema") && !data.contains("grammar")) { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
try { try {
auto schema = json_value(data, "json_schema", json::object()); auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema); params.sampling.grammar = json_schema_to_grammar(schema);
@ -345,11 +345,6 @@ struct server_task {
} }
} }
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) {
{ {
params.sampling.logit_bias.clear(); params.sampling.logit_bias.clear();
params.ignore_eos = json_value(data, "ignore_eos", false); params.ignore_eos = json_value(data, "ignore_eos", false);

View file

@ -221,34 +221,22 @@ PYTHON_TOOL = {
} }
} }
CODE_INTEPRETER_TOOL = {
"type": "code_interpreter",
}
@pytest.mark.parametrize("template_name,tool,argument_key", [ @pytest.mark.parametrize("template_name,tool,argument_key", [
# TODO: fix special handling of python tool for these templates:
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None), # "code"), # TODO: fix ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None), # "code"), # TODO: fix ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None), # "code"), # TODO: fix ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
]) ])
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None): def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
n_predict = 512 n_predict = 512
@ -321,29 +309,19 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [ @pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
(PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
(PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
# TODO: fix these models # TODO: fix these models
# (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), (PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
# (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), # (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# # (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), # (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# # (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
]) ])
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
global server global server
server.n_slots = 2 server.n_slots = 2
server.jinja = True server.jinja = True
@ -377,9 +355,15 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
assert tool["function"]["name"] == tool_call["function"]["name"] assert tool["function"]["name"] == tool_call["function"]["name"]
elif tool["type"] == "code_interpreter": elif tool["type"] == "code_interpreter":
assert re.match('i?python', tool_call["function"]["name"]) assert re.match('i?python', tool_call["function"]["name"])
actual_arguments = json.loads(tool_call["function"]["arguments"]) actual_arguments = tool_call["function"]["arguments"]
code = actual_arguments["code"] if expected_arguments is not None:
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' assert actual_arguments == expected_arguments
else:
actual_arguments = json.loads(actual_arguments)
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
def test_logprobs(): def test_logprobs():