tool-call: remove nonsensical code_interpreter code
This commit is contained in:
parent
bddc1bebcc
commit
da606d8d41
4 changed files with 165 additions and 240 deletions
|
@ -1,6 +1,7 @@
|
||||||
#include "chat-handler.hpp"
|
#include "chat-handler.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
|
#include "log.h"
|
||||||
#include "minja.hpp"
|
#include "minja.hpp"
|
||||||
|
|
||||||
const common_grammar_options grammar_options {
|
const common_grammar_options grammar_options {
|
||||||
|
@ -58,7 +59,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
||||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||||
* Aggregates the prefix, suffix and in-between text into the content.
|
* Aggregates the prefix, suffix and in-between text into the content.
|
||||||
*/
|
*/
|
||||||
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) {
|
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
|
|
||||||
common_chat_msg result;
|
common_chat_msg result;
|
||||||
|
@ -69,15 +70,10 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
||||||
std::vector<std::string> tool_names;
|
std::vector<std::string> tool_names;
|
||||||
if (check_names) {
|
if (check_names) {
|
||||||
for (const auto & tool : tools) {
|
for (const auto & tool : tools) {
|
||||||
if (!tool.contains("type")) {
|
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string type = tool.at("type");
|
tool_names.push_back(tool["function"]["name"]);
|
||||||
if (type == "function") {
|
|
||||||
tool_names.push_back(tool["function"]["name"]);
|
|
||||||
} else if (type == "code_interpreter") {
|
|
||||||
tool_names.push_back("python");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +98,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
||||||
|
|
||||||
json arguments;
|
json arguments;
|
||||||
if (!parse_json(it, end, arguments)) {
|
if (!parse_json(it, end, arguments)) {
|
||||||
if (has_python && name == "python" && std::regex_match("", close_regex)) {
|
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
|
||||||
std::string src(it, end);
|
std::string src(it, end);
|
||||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||||
break;
|
break;
|
||||||
|
@ -207,42 +203,22 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto python_tool = json::parse(R"({
|
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "python",
|
|
||||||
"description": "an ipython interpreter",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"code": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Python code to execute."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["code"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})");
|
|
||||||
|
|
||||||
static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
|
|
||||||
for (const auto & tool : tools) {
|
for (const auto & tool : tools) {
|
||||||
if (!tool.contains("type")) {
|
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
|
||||||
|
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (tool["type"] == "code_interpreter") {
|
fn(tool);
|
||||||
fn(python_tool);
|
|
||||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
|
||||||
fn(tool);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
auto tool_call_schemas = json::array();
|
auto tool_call_schemas = json::array();
|
||||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
auto tool_schema = json {
|
auto tool_schema = json {
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -348,10 +324,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -392,108 +369,91 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
auto has_python = false;
|
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
auto add_tool = [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
builder.resolve_refs(parameters);
|
builder.resolve_refs(parameters);
|
||||||
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
|
tool_rules.push_back(
|
||||||
has_python = true;
|
builder.add_rule(
|
||||||
} else {
|
name + "-call",
|
||||||
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
|
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||||
tool_rules.push_back(
|
builder.add_schema(name + "-args", parameters) +
|
||||||
builder.add_rule(
|
" \"}\""));
|
||||||
name + "-call",
|
});
|
||||||
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
if (params.tool_choice != "required") {
|
||||||
builder.add_schema(name + "-args", parameters) +
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
" \"}\""));
|
|
||||||
if (params.tool_choice != "required" && !eagerly_match_any_json) {
|
|
||||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
|
|
||||||
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
|
|
||||||
// Note that c++11's regex doesn't support partial matches, otherwise it would make
|
|
||||||
// sense to add support for trigger regexes to the antiprompt mechanism.
|
|
||||||
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
|
|
||||||
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
|
||||||
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
|
||||||
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (const auto & tool : params.tools) {
|
|
||||||
if (!tool.contains("type")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tool["type"] == "code_interpreter") {
|
|
||||||
builtin_tools.push_back("code_interpreter");
|
|
||||||
has_python = true;
|
|
||||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
|
||||||
add_tool(tool);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_python) {
|
|
||||||
if (uses_python_tag) {
|
|
||||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
|
||||||
if (params.tool_choice != "required") {
|
|
||||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
add_tool(python_tool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.tool_choice != "required" && eagerly_match_any_json) {
|
|
||||||
data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
|
|
||||||
data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
|
|
||||||
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
|
||||||
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.additional_stops.push_back("<|eom_id|>");
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
|
||||||
{"builtin_tools", builtin_tools},
|
{"builtin_tools", builtin_tools},
|
||||||
});
|
});
|
||||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
if (has_python && uses_python_tag) {
|
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
|
||||||
std::smatch match;
|
|
||||||
if (std::regex_search(input, match, python_tag_regex)) {
|
|
||||||
return {
|
|
||||||
/* .role = */ "assistant",
|
|
||||||
/* .content = */ match.prefix().str(),
|
|
||||||
/* .tool_calls = */ {
|
|
||||||
{
|
|
||||||
/* .name = */ "python",
|
|
||||||
/* .arguments = */ match[1].str(),
|
|
||||||
/* .id = */ "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
|
||||||
static std::regex close_regex("\\}");
|
static std::regex close_regex("\\}");
|
||||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag);
|
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
|
return res;
|
||||||
});
|
});
|
||||||
|
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
|
common_chat_data data;
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
// auto add_tool = [&](const json & tool) {
|
||||||
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
tool_rules.push_back(
|
||||||
|
builder.add_rule(
|
||||||
|
name + "-call",
|
||||||
|
"\"{\" "
|
||||||
|
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
|
||||||
|
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||||
|
builder.add_schema(name + "-args", parameters) +
|
||||||
|
" \"}\""));
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
|
}, grammar_options);
|
||||||
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
|
||||||
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
|
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||||
|
static std::regex close_regex("\\}");
|
||||||
|
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
|
return res;
|
||||||
|
});
|
||||||
|
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -528,85 +488,92 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
auto has_python = false;
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> first_tool_rules;
|
std::vector<std::string> first_tool_rules;
|
||||||
std::vector<std::string> subsequent_tool_rules;
|
std::vector<std::string> subsequent_tool_rules;
|
||||||
for (const auto & tool : params.tools) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
if (!tool.contains("type")) {
|
const auto & function = tool["function"];
|
||||||
continue;
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||||
|
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||||
|
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||||
}
|
}
|
||||||
if (tool["type"] == "code_interpreter") {
|
});
|
||||||
has_python = true;
|
|
||||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
|
||||||
const auto & function = tool["function"];
|
|
||||||
std::string name = function["name"];
|
|
||||||
auto parameters = function["parameters"];
|
|
||||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
|
||||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
|
||||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
|
|
||||||
if (params.tool_choice != "required") {
|
|
||||||
data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
|
|
||||||
data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||||
// Note: if there's a python rule, it needs to come last.
|
|
||||||
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
|
|
||||||
if (has_python && params.tool_choice != "required") {
|
|
||||||
data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
|
|
||||||
data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
|
|
||||||
}
|
|
||||||
if (params.parallel_tool_calls) {
|
if (params.parallel_tool_calls) {
|
||||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||||
builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
|
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
||||||
} else {
|
} else {
|
||||||
builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : ""));
|
builder.add_rule("root", first_rule);
|
||||||
}
|
}
|
||||||
|
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||||
static std::regex close_regex(R"($|(?=>>>))");
|
static std::regex close_regex(R"($|(?=>>>))");
|
||||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
|
|
||||||
|
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||||
|
if (res.content.find("all\n") == 0) {
|
||||||
|
res.content = res.content.substr(4);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
});
|
});
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
// TODO: handle tool {type: code_interpreter} as python
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||||
|
std::string python_code_argument_name;
|
||||||
|
auto has_raw_python = false;
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
auto has_python = false;
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
for (const auto & tool : params.tools) {
|
const auto & function = tool["function"];
|
||||||
if (!tool.contains("type")) {
|
const auto & parameters = function["parameters"];
|
||||||
continue;
|
std::string name = function["name"];
|
||||||
}
|
if (name == "python" || name == "ipython") {
|
||||||
if (tool["type"] == "code_interpreter") {
|
if (!parameters.contains("type")) {
|
||||||
has_python = true;
|
throw std::runtime_error("Missing type in python tool");
|
||||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
|
||||||
const auto & function = tool["function"];
|
|
||||||
std::string name = function["name"];
|
|
||||||
if (name == "python" || name == "ipython") {
|
|
||||||
has_python = true;
|
|
||||||
} else {
|
|
||||||
auto parameters = function["parameters"];
|
|
||||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
|
||||||
}
|
}
|
||||||
|
has_raw_python = true;
|
||||||
|
auto type = parameters.at("type");
|
||||||
|
if (type == "object") {
|
||||||
|
auto properties = parameters.at("properties");
|
||||||
|
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||||
|
if (it.value().at("type") == "string") {
|
||||||
|
if (!python_code_argument_name.empty()) {
|
||||||
|
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||||
|
}
|
||||||
|
python_code_argument_name = it.key();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (python_code_argument_name.empty()) {
|
||||||
|
throw std::runtime_error("No string argument found in python tool");
|
||||||
|
}
|
||||||
|
} else if (type != "string") {
|
||||||
|
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
if (has_python) {
|
if (has_raw_python) {
|
||||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||||
if (params.tool_choice != "required") {
|
if (params.tool_choice != "required") {
|
||||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
|
@ -620,18 +587,19 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
if (std::regex_search(input, match, python_tag_regex)) {
|
if (std::regex_search(input, match, python_tag_regex)) {
|
||||||
|
auto code = match[1].str();
|
||||||
return {
|
return {
|
||||||
/* .role = */ "assistant",
|
/* .role = */ "assistant",
|
||||||
/* .content = */ match.prefix().str(),
|
/* .content = */ match.prefix().str(),
|
||||||
/* .tool_calls = */ {
|
/* .tool_calls = */ {
|
||||||
{
|
{
|
||||||
/* .name = */ "python",
|
/* .name = */ "python",
|
||||||
/* .arguments = */ match[1].str(),
|
/* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
|
||||||
/* .id = */ "",
|
/* .id = */ "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -639,17 +607,18 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
||||||
}
|
}
|
||||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||||
static std::regex close_regex(R"(</function>)");
|
static std::regex close_regex(R"(</function>)");
|
||||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false);
|
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||||
});
|
});
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -719,6 +688,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.parser = std::make_unique<text_chat_parser>();
|
data.parser = std::make_unique<text_chat_parser>();
|
||||||
|
@ -756,13 +726,11 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
|
||||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||||
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
|
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
|
||||||
|
|
||||||
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
|
if (uses_python_tag) {
|
||||||
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
|
return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
|
||||||
// as it seems to be outputting some JSON.
|
} else {
|
||||||
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
|
return common_chat_init_llama_3_2_tool_calls(tmpl, params);
|
||||||
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
|
}
|
||||||
|
|
||||||
return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
|
|
||||||
}
|
}
|
||||||
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||||
// TODO: Command-R-Plus
|
// TODO: Command-R-Plus
|
||||||
|
|
|
@ -143,28 +143,10 @@ class chat_template {
|
||||||
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
|
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
|
||||||
actual_tools = json::array();
|
actual_tools = json::array();
|
||||||
for (const auto & tool : tools) {
|
for (const auto & tool : tools) {
|
||||||
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
|
if (tool.contains("type") && tool.at("type") == "code_interpreter" && !supports_code_interpreter_) {
|
||||||
static const auto python_tool = json::parse(R"({
|
continue;
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "python",
|
|
||||||
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"code": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The code to run in the ipython interpreter."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["code"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})");
|
|
||||||
actual_tools.push_back(python_tool);
|
|
||||||
} else {
|
|
||||||
actual_tools.push_back(tool);
|
|
||||||
}
|
}
|
||||||
|
actual_tools.push_back(tool);
|
||||||
}
|
}
|
||||||
} else if (!tools.is_null()) {
|
} else if (!tools.is_null()) {
|
||||||
actual_tools = tools;
|
actual_tools = tools;
|
||||||
|
@ -295,10 +277,6 @@ class chat_template {
|
||||||
if (!tools.is_null()) {
|
if (!tools.is_null()) {
|
||||||
auto tools_val = minja::Value(actual_tools);
|
auto tools_val = minja::Value(actual_tools);
|
||||||
context->set("tools", tools_val);
|
context->set("tools", tools_val);
|
||||||
if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
|
|
||||||
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
|
|
||||||
context->set("builtin_tools", builtin_tools_val);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (!extra_context.is_null()) {
|
if (!extra_context.is_null()) {
|
||||||
for (auto & kv : extra_context.items()) {
|
for (auto & kv : extra_context.items()) {
|
||||||
|
|
|
@ -333,7 +333,7 @@ struct server_task {
|
||||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||||
}
|
}
|
||||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||||
try {
|
try {
|
||||||
auto schema = json_value(data, "json_schema", json::object());
|
auto schema = json_value(data, "json_schema", json::object());
|
||||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||||
|
@ -345,11 +345,6 @@ struct server_task {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// process "json_schema" and "grammar"
|
|
||||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
|
||||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
|
||||||
}
|
|
||||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
||||||
{
|
{
|
||||||
params.sampling.logit_bias.clear();
|
params.sampling.logit_bias.clear();
|
||||||
params.ignore_eos = json_value(data, "ignore_eos", false);
|
params.ignore_eos = json_value(data, "ignore_eos", false);
|
||||||
|
|
|
@ -221,34 +221,22 @@ PYTHON_TOOL = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CODE_INTEPRETER_TOOL = {
|
|
||||||
"type": "code_interpreter",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||||
# TODO: fix special handling of python tool for these templates:
|
|
||||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None), # "code"), # TODO: fix
|
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||||
("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||||
("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||||
("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||||
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None), # "code"), # TODO: fix
|
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
|
|
||||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None), # "code"), # TODO: fix
|
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
|
|
||||||
])
|
])
|
||||||
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
|
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
|
||||||
n_predict = 512
|
n_predict = 512
|
||||||
|
@ -321,29 +309,19 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
|
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
|
||||||
|
(PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||||
|
(PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||||
|
(PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||||
|
(PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||||
|
(PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||||
|
(PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||||
|
(PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||||
# TODO: fix these models
|
# TODO: fix these models
|
||||||
# (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
(PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||||
# (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||||
# # (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||||
# # (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
|
||||||
# (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
|
||||||
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
|
||||||
# (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
# (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
|
||||||
(PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
|
||||||
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)),
|
|
||||||
])
|
])
|
||||||
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||||
global server
|
global server
|
||||||
server.n_slots = 2
|
server.n_slots = 2
|
||||||
server.jinja = True
|
server.jinja = True
|
||||||
|
@ -377,9 +355,15 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
|
||||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||||
elif tool["type"] == "code_interpreter":
|
elif tool["type"] == "code_interpreter":
|
||||||
assert re.match('i?python', tool_call["function"]["name"])
|
assert re.match('i?python', tool_call["function"]["name"])
|
||||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
actual_arguments = tool_call["function"]["arguments"]
|
||||||
code = actual_arguments["code"]
|
if expected_arguments is not None:
|
||||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
assert actual_arguments == expected_arguments
|
||||||
|
else:
|
||||||
|
actual_arguments = json.loads(actual_arguments)
|
||||||
|
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||||
|
code = actual_arguments["code"]
|
||||||
|
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||||
|
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||||
|
|
||||||
|
|
||||||
def test_logprobs():
|
def test_logprobs():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue