tool-call: remove nonsensical code_interpreter code
This commit is contained in:
parent
bddc1bebcc
commit
da606d8d41
4 changed files with 165 additions and 240 deletions
|
@ -1,6 +1,7 @@
|
|||
#include "chat-handler.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "minja.hpp"
|
||||
|
||||
const common_grammar_options grammar_options {
|
||||
|
@ -58,7 +59,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) {
|
||||
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
|
@ -69,15 +70,10 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
|||
std::vector<std::string> tool_names;
|
||||
if (check_names) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type")) {
|
||||
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
std::string type = tool.at("type");
|
||||
if (type == "function") {
|
||||
tool_names.push_back(tool["function"]["name"]);
|
||||
} else if (type == "code_interpreter") {
|
||||
tool_names.push_back("python");
|
||||
}
|
||||
tool_names.push_back(tool["function"]["name"]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +98,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
|||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (has_python && name == "python" && std::regex_match("", close_regex)) {
|
||||
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
|
||||
std::string src(it, end);
|
||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||
break;
|
||||
|
@ -207,42 +203,22 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
const auto python_tool = json::parse(R"({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "python",
|
||||
"description": "an ipython interpreter",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute."
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
})");
|
||||
|
||||
static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type")) {
|
||||
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
fn(python_tool);
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
fn(tool);
|
||||
}
|
||||
fn(tool);
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
auto tool_schema = json {
|
||||
{"type", "object"},
|
||||
|
@ -348,10 +324,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
|
@ -392,108 +369,91 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
||||
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||
common_chat_data data;
|
||||
auto has_python = false;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto add_tool = [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
|
||||
has_python = true;
|
||||
} else {
|
||||
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
if (params.tool_choice != "required" && !eagerly_match_any_json) {
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
|
||||
// Note that c++11's regex doesn't support partial matches, otherwise it would make
|
||||
// sense to add support for trigger regexes to the antiprompt mechanism.
|
||||
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
}
|
||||
}
|
||||
};
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
builtin_tools.push_back("code_interpreter");
|
||||
has_python = true;
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
add_tool(tool);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
});
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
|
||||
if (has_python) {
|
||||
if (uses_python_tag) {
|
||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
} else {
|
||||
add_tool(python_tool);
|
||||
}
|
||||
}
|
||||
|
||||
if (params.tool_choice != "required" && eagerly_match_any_json) {
|
||||
data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
||||
}
|
||||
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
|
||||
{"builtin_tools", builtin_tools},
|
||||
});
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg {
|
||||
if (has_python && uses_python_tag) {
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ match[1].str(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag);
|
||||
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
return res;
|
||||
});
|
||||
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
// auto add_tool = [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" "
|
||||
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
|
||||
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
}
|
||||
});
|
||||
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
return res;
|
||||
});
|
||||
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
|
@ -528,85 +488,92 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_data data;
|
||||
|
||||
auto has_python = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
}
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
has_python = true;
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
// Note: if there's a python rule, it needs to come last.
|
||||
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
|
||||
if (has_python && params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
|
||||
}
|
||||
if (params.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
|
||||
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||
} else {
|
||||
builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : ""));
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
|
||||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
|
||||
|
||||
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||
if (res.content.find("all\n") == 0) {
|
||||
res.content = res.content.substr(4);
|
||||
}
|
||||
return res;
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
// TODO: handle tool {type: code_interpreter} as python
|
||||
common_chat_data data;
|
||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
auto has_python = false;
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
}
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
has_python = true;
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
if (name == "python" || name == "ipython") {
|
||||
has_python = true;
|
||||
} else {
|
||||
auto parameters = function["parameters"];
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
const auto & parameters = function["parameters"];
|
||||
std::string name = function["name"];
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
auto type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
} else {
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
}
|
||||
}
|
||||
if (has_python) {
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
|
@ -620,18 +587,19 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
|||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ match[1].str(),
|
||||
/* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
|
@ -639,17 +607,18 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
|||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false);
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -719,6 +688,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||
data.parser = std::make_unique<text_chat_parser>();
|
||||
|
@ -756,13 +726,11 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
|
|||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
|
||||
|
||||
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
|
||||
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
|
||||
// as it seems to be outputting some JSON.
|
||||
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
|
||||
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
|
||||
|
||||
return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
|
||||
if (uses_python_tag) {
|
||||
return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
|
||||
} else {
|
||||
return common_chat_init_llama_3_2_tool_calls(tmpl, params);
|
||||
}
|
||||
}
|
||||
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||
// TODO: Command-R-Plus
|
||||
|
|
|
@ -143,28 +143,10 @@ class chat_template {
|
|||
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
|
||||
actual_tools = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
|
||||
static const auto python_tool = json::parse(R"({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "python",
|
||||
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to run in the ipython interpreter."
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
})");
|
||||
actual_tools.push_back(python_tool);
|
||||
} else {
|
||||
actual_tools.push_back(tool);
|
||||
if (tool.contains("type") && tool.at("type") == "code_interpreter" && !supports_code_interpreter_) {
|
||||
continue;
|
||||
}
|
||||
actual_tools.push_back(tool);
|
||||
}
|
||||
} else if (!tools.is_null()) {
|
||||
actual_tools = tools;
|
||||
|
@ -295,10 +277,6 @@ class chat_template {
|
|||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(actual_tools);
|
||||
context->set("tools", tools_val);
|
||||
if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
|
||||
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
|
||||
context->set("builtin_tools", builtin_tools_val);
|
||||
}
|
||||
}
|
||||
if (!extra_context.is_null()) {
|
||||
for (auto & kv : extra_context.items()) {
|
||||
|
|
|
@ -333,7 +333,7 @@ struct server_task {
|
|||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
|
@ -345,11 +345,6 @@ struct server_task {
|
|||
}
|
||||
}
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
params.ignore_eos = json_value(data, "ignore_eos", false);
|
||||
|
|
|
@ -221,34 +221,22 @@ PYTHON_TOOL = {
|
|||
}
|
||||
}
|
||||
|
||||
CODE_INTEPRETER_TOOL = {
|
||||
"type": "code_interpreter",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
# TODO: fix special handling of python tool for these templates:
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None), # "code"), # TODO: fix
|
||||
("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None), # "code"), # TODO: fix
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None), # "code"), # TODO: fix
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
|
||||
n_predict = 512
|
||||
|
@ -321,29 +309,19 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
|
||||
(PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||
# TODO: fix these models
|
||||
# (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
# (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
# # (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# # (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
|
||||
(PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
])
|
||||
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 2
|
||||
server.jinja = True
|
||||
|
@ -377,9 +355,15 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
|
|||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
elif tool["type"] == "code_interpreter":
|
||||
assert re.match('i?python', tool_call["function"]["name"])
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
code = actual_arguments["code"]
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments is not None:
|
||||
assert actual_arguments == expected_arguments
|
||||
else:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
|
||||
|
||||
def test_logprobs():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue