WIP chat handlers

parent 46415d7a51 · commit 36ed106f84

13 changed files with 1071 additions and 999 deletions
Makefile

@@ -1363,10 +1363,11 @@ llama-server: \
	examples/server/httplib.h \
	examples/server/index.html.hpp \
	examples/server/loading.html.hpp \
	common/chat-handler.cpp \
	common/chat-handler.hpp \
	common/chat-template.hpp \
	common/json.hpp \
	common/minja.hpp \
	common/tool-call.h \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

common/CMakeLists.txt

@@ -56,6 +56,8 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-handler.cpp
+    chat-handler.hpp
     chat-template.hpp
     common.cpp
     common.h

@@ -72,7 +74,6 @@ add_library(${TARGET} STATIC
     sampling.h
     speculative.cpp
     speculative.h
-    tool-call.cpp
     )

 if (BUILD_SHARED_LIBS)

common/chat-handler.cpp (new file)

@@ -0,0 +1,720 @@
#include "chat-handler.hpp"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h"
#include "minja.hpp"

const common_grammar_options grammar_options {
    /* .dotall = */ false,
    /* .compact_spaces = */ false,
    // /* .compact_spaces = */ true,
};
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
    // https://json.nlohmann.me/features/parsing/sax_interface/
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;
        bool found_error;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
            this->position = position - 1;
            this->found_error = true;
            return false;
        }
        bool null() override { return true; }
        bool boolean(bool) override { return true; }
        bool number_integer(number_integer_t) override { return true; }
        bool number_unsigned(number_unsigned_t) override { return true; }
        bool number_float(number_float_t, const string_t &) override { return true; }
        bool string(string_t &) override { return true; }
        bool binary(binary_t &) override { return true; }
        bool start_object(std::size_t) override { return true; }
        bool key(string_t &) override { return true; }
        bool end_object() override { return true; }
        bool start_array(std::size_t) override { return true; }
        bool end_array() override { return true; }
    };
    json_error_locator err_loc;
    json::sax_parse(it, end, &err_loc);

    std::string::const_iterator temptative_end;
    if (err_loc.found_error) {
        temptative_end = it + err_loc.position;
    } else {
        temptative_end = end;
    }
    std::string json_sub {it, temptative_end};
    try {
        out = json::parse(json_sub);
        it = temptative_end;
        return true;
    } catch (const std::exception &) {
        return false;
    }
}
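Note on the helper above: it leans on nlohmann's SAX interface to locate where a JSON value stops inside a longer string, which is what lets the tool-call parsers below peel a JSON object off mixed text. A minimal sketch of the contract (hypothetical standalone usage):

    std::string s = "{\"a\": 1} trailing text";
    auto it = s.cbegin();
    json out;
    if (parse_json(it, s.cend(), out)) {
        // out holds {"a": 1}; it now points at " trailing text", because the
        // SAX error locator recorded where the first JSON value ended.
    }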
/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string & input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";
    auto end = input.end();
    auto it = input.begin();

    std::vector<std::string> tool_names;
    if (check_names) {
        for (const auto & tool : tools) {
            if (!tool.contains("type")) {
                continue;
            }
            std::string type = tool.at("type");
            if (type == "function") {
                tool_names.push_back(tool["function"]["name"]);
            } else if (type == "code_interpreter") {
                tool_names.push_back("ipython");
            }
        }
    }

    while (it != end) {
        std::sregex_iterator rend;
        std::sregex_iterator rit(it, end, function_regex);
        if (rit == rend) {
            fprintf(stderr, "No more tool calls found\n");
            result.content += std::string(it, end);
            break;
        }
        auto name = rit->str(1);
        if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
            fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
            result.content += std::string(it, rit->suffix().first);
            break;
        }

        result.content += std::string(it, rit->prefix().second);
        it = rit->suffix().first;

        json arguments;
        if (!parse_json(it, end, arguments)) {
            throw std::runtime_error("Failed to parse json tool call arguments");
        }
        if (!std::regex_search(it, end, match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern");
        }
        it = match.suffix().first;
        result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
    }
    return result;
}
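The contract, on a hypothetical input: parsing ">>>get_weather\n{\"city\": \"Paris\"}" with function_regex R"((?:>>>)?(\w+)\n)" and close_regex R"($|(?=>>>))" (the Functionary v3 patterns used later in this file) would yield an assistant message with empty content and a single tool call {name: "get_weather", arguments: "{\"city\":\"Paris\"}", id: ""}.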
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string & input, const std::string & prefix, size_t rstrip_prefix = 0) {
    auto content_end = input.find(prefix);
    size_t tc_start = std::string::npos;

    common_chat_msg result;
    result.role = "assistant";
    const auto process_tool_calls = [&](const json & tool_calls) {
        for (const auto & tool_call : tool_calls) {
            const auto & arguments = tool_call["arguments"];
            result.tool_calls.push_back({
                tool_call["name"],
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    };
    if (content_end == std::string::npos) {
        result.content = input;
    } else {
        tc_start = content_end + prefix.size() - rstrip_prefix;
        result.content = input.substr(0, content_end);
        auto tool_calls = json::parse(input.substr(tc_start));
        process_tool_calls(tool_calls);
    }
    return result;
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
    json messages_with_system = messages;

    if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
        std::string existing_system = messages_with_system.at(0).at("content");
        messages_with_system[0] = json {
            {"role", "system"},
            {"content", existing_system + "\n" + system_prompt},
        };
    } else {
        messages_with_system.insert(messages_with_system.begin(), json {
            {"role", "system"},
            {"content", system_prompt},
        });
    }
    return messages_with_system;
}
class text_chat_parser : public common_chat_parser {
public:
    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
        return parse_final(input);
    }

    common_chat_msg parse_final(const std::string & input) override {
        return {
            /* .role = */ "assistant",
            /* .content = */ input,
            /* .tool_calls = */ {},
        };
    }
};
class monolithic_chat_parser : public common_chat_parser {

    std::string input_buffer_;
    std::function<common_chat_msg(const std::string & input)> parse_final_;

public:
    monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}

    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
        input_buffer_ += input;
        return std::nullopt;
    }

    common_chat_msg parse_final(const std::string & input) override {
        input_buffer_ += input;
        auto out = parse_final_(input_buffer_);
        input_buffer_.clear();
        return out;
    }
};
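The split matters for streaming: parse_partial is called once per generated chunk and only buffers (tool-call formats cannot be reliably parsed mid-stream), while parse_final runs the format-specific parser once over the whole accumulated output. A usage sketch with a trivial final parser:

    monolithic_chat_parser p([](const std::string & full) {
        return common_chat_msg {"assistant", full, {}};
    });
    p.parse_partial("{\"name\": ");  // returns std::nullopt: just buffered
    p.parse_partial("\"foo\"}");     // returns std::nullopt
    auto msg = p.parse_final("");    // parses the full buffered output once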
static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;

    auto tool_call_schemas = json::array();
    for (const auto & tool : params.tools) {
        const auto & function = tool["function"];
        auto tool_schema = json {
            {"type", "object"},
            {"properties", {
                {"name", {
                    {"type", "string"},
                    {"const", function["name"]},
                }},
                {"arguments", function["parameters"]},
            }},
            {"required", json::array({"name", "arguments"})},
        };
        if (function.contains("description")) {
            tool_schema["description"] = function["description"];
        }
        if (params.parallel_tool_calls) {
            tool_schema["properties"]["id"] = {
                {"type", "string"},
                {"minLength", 4},
            };
            tool_schema["required"].push_back("id");
        }
        tool_call_schemas.emplace_back(tool_schema);
    }
    const auto tool_call =
        params.parallel_tool_calls
            ? json {
                {"type", "object"},
                {"properties", {
                    {"tool_calls", {
                        {"type", "array"},
                        {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                            {"anyOf", tool_call_schemas},
                        }},
                        {"minItems", 1},
                    }},
                }},
                {"required", json::array({"tool_calls"})},
            }
            : json {
                {"type", "object"},
                {"properties", {
                    {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                        {"anyOf", tool_call_schemas},
                    }},
                }},
                {"required", json::array({"tool_call"})},
            };
    const auto schema =
        params.tool_choice != "required"
            ? json {
                {"anyOf", json::array({
                    tool_call,
                    {
                        {"type", "object"},
                        {"properties", {
                            {"response", params.json_schema.is_null()
                                ? json {{"type", "string"}}
                                : params.json_schema
                            },
                        }},
                        {"required", json::array({"response"})},
                    },
                })}
            }
            : tool_call;

    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        builder.add_schema("root", schema);
    }, grammar_options);

    // TODO: add schema to system prompt.
    auto tweaked_messages = add_system(
        params.messages,
        "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");

    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
        json data = json::parse(input);
        common_chat_msg result;
        result.role = "assistant";
        if (data.contains("tool_calls")) {
            for (const auto & tool_call : data["tool_calls"]) {
                result.tool_calls.push_back({
                    tool_call["name"],
                    tool_call["arguments"].dump(),
                    tool_call.contains("id") ? tool_call["id"] : "",
                });
            }
        } else if (data.contains("tool_call")) {
            result.tool_calls.push_back({
                data["tool_call"]["name"],
                data["tool_call"]["arguments"].dump(),
                /* id= */ "",
            });
        } else if (data.contains("response")) {
            const auto & response = data["response"];
            result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
        }
        return result;
    });
    return data;
}
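With this generic handler the entire reply is constrained to the schema above, so for a single hypothetical get_weather tool the model output looks like one of:

    {"tool_call": {"name": "get_weather", "arguments": {"city": "Paris"}}}
    {"response": "It is sunny in Paris today."}

(with a "tool_calls" array instead when parallel_tool_calls is set).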
static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;
    auto builtin_tools = json {"wolfram_alpha", "brave_search"};

    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        for (const auto & tool : params.tools) {
            const auto & function = tool["function"];
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    // Important note: the model is probably trained to take a JSON stringified arguments value.
                    // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
                    {"name", {
                        {"type", "string"},
                        {"const", function["name"]},
                    }},
                    {"arguments", function["parameters"]},
                    {"id", {
                        {"type", "string"},
                        // Nemo's template expects a 9-character alphanumeric ID.
                        {"pattern", "^[a-zA-Z0-9]{9}$"},
                    }},
                }},
                {"required", json::array({"name", "arguments", "id"})},
            });
        }
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!params.parallel_tool_calls) {
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
    }, grammar_options);
    if (params.tool_choice != "required") {
        data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
    }
    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
        return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
    });
    return data;
}
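The wire format this enforces (illustrative): free-form text may come first, then

    [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "aB3dE5gH9"}]

which parse_prefixed_json_tool_call_array splits into content plus tool calls.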
static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
    auto builtin_tools = json {"wolfram_alpha", "brave_search"};
    for (const auto & tool : params.tools) {
        if (!tool.contains("type")) {
            continue;
        }
        if (tool["type"] == "code_interpreter") {
            builtin_tools.push_back("code_interpreter");
            break;
        }
    }

    common_chat_data data;

    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;

        auto has_python = false;

        for (const auto & tool : params.tools) {
            if (!tool.contains("type")) {
                continue;
            }

            if (tool["type"] == "code_interpreter") {
                has_python = true;
            } else if (tool["type"] == "function" && tool.contains("function")) {
                const auto & function = tool["function"];
                std::string name = function["name"];
                auto parameters = function["parameters"];
                builder.resolve_refs(parameters);
                if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
                    has_python = true;
                } else {
                    //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
                    tool_rules.push_back(
                        builder.add_rule(
                            name + "-call",
                            "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                                builder.add_schema(name + "-args", parameters) +
                            " \"}\""));
                    if (params.tool_choice != "required" && !eagerly_match_any_json) {
                        data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
                        // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
                        // Note that c++11's regex doesn't support partial matches, otherwise it would make
                        // sense to add support for trigger regexes to the antiprompt mechanism.
                        data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
                        data.grammar_triggers.push_back({"{\n  \"name\": \"" + name + "\"", /* .at_start = */ false});
                        data.grammar_triggers.push_back({"{\n    \"name\": \"" + name + "\"", /* .at_start = */ false});
                        data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
                    }
                }
            }
        }

        if (has_python) {
            tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
            if (params.tool_choice != "required") {
                data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
            }
        }

        if (params.tool_choice != "required" && eagerly_match_any_json) {
            data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
            data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
            data.grammar_triggers.push_back({"{\n  \"", /* .at_start = */ true});
            data.grammar_triggers.push_back({"{\n    \"", /* .at_start = */ true});
        }

        builder.add_rule("root", string_join(tool_rules, " | "));
    }, grammar_options);
    data.additional_stops.push_back("<|eom_id|>");
    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
        if (uses_python_tag) {
            static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
            std::smatch match;
            if (std::regex_search(input, match, python_tag_regex)) {
                return {
                    /* .role = */ "assistant",
                    /* .content = */ match.prefix().str(),
                    /* .tool_calls = */ {
                        {
                            /* .name = */ "python",
                            /* .arguments = */ match[1].str(),
                            /* .id = */ "",
                        },
                    }
                };
            }
        }
        static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
        static std::regex close_regex("\\}");
        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
    });
    return data;
}
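The two Llama 3.x shapes this handler accepts (illustrative):

    {"name": "get_weather", "parameters": {"city": "Paris"}}
    <|python_tag|>print("hello")

the latter only when uses_python_tag is set, in which case the raw tail after the tag becomes the arguments of a python tool call.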
static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        for (const auto & tool : params.tools) {
            const auto & function = tool["function"];
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    {"name", {
                        {"type", "string"},
                        {"const", function["name"]},
                    }},
                    {"arguments", function["parameters"]},
                }},
                {"required", json::array({"name", "arguments", "id"})},
            });
        }
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!params.parallel_tool_calls) {
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
    }, grammar_options);
    if (params.tool_choice != "required") {
        data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
    }
    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
        return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
    });
    return data;
}
static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
    // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
    common_chat_data data;

    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> first_tool_rules;
        std::vector<std::string> subsequent_tool_rules;
        auto has_python = false;
        for (const auto & tool : params.tools) {
            if (!tool.contains("type")) {
                continue;
            }
            if (tool["type"] == "code_interpreter") {
                has_python = true;
            } else if (tool["type"] == "function" && tool.contains("function")) {
                const auto & function = tool["function"];
                std::string name = function["name"];
                auto parameters = function["parameters"];
                auto args_rule = builder.add_schema(name + "-args", parameters);
                first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
                if (params.tool_choice != "required") {
                    data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
                    data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
                }
            }
        }
        auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
        // Note: if there's a python rule, it needs to come last.
        auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
        if (has_python && params.tool_choice != "required") {
            data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
            data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
        }
        if (params.parallel_tool_calls) {
            auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
            builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
        } else {
            builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : ""));
        }
    }, grammar_options);

    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
        static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
        static std::regex close_regex(R"($|(?=>>>))");
        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
    });
    return data;
}
static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
    // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
    // TODO: handle tool {type: code_interpreter} as python
    common_chat_data data;

    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        auto has_python = false;
        for (const auto & tool : params.tools) {
            if (!tool.contains("type")) {
                continue;
            }
            if (tool["type"] == "code_interpreter") {
                has_python = true;
            } else if (tool["type"] == "function" && tool.contains("function")) {
                const auto & function = tool["function"];
                std::string name = function["name"];
                if (name == "python" || name == "ipython") {
                    has_python = true;
                } else {
                    auto parameters = function["parameters"];
                    tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
                }
            }
        }
        if (has_python) {
            tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
            if (params.tool_choice != "required") {
                data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
            }
        }
        auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
        if (params.tool_choice != "required") {
            data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
        }
    }, grammar_options);

    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
        // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
        static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
        std::smatch match;
        if (std::regex_search(input, match, python_tag_regex)) {
            return {
                /* .role = */ "assistant",
                /* .content = */ match.prefix().str(),
                /* .tool_calls = */ {
                    {
                        /* .name = */ "python",
                        /* .arguments = */ match[1].str(),
                        /* .id = */ "",
                    },
                }
            };
        }
        static std::regex function_regex(R"(<function=(\w+)>)");
        static std::regex close_regex(R"(</function>)");
        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
    });
    return data;
}
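Functionary v3.1 wire format (illustrative): prose may precede calls such as

    <function=get_weather>{"city": "Paris"}</function>

with <|python_tag|>... still accepted for the python tool, as handled above.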
static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;
    // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        for (const auto & tool : params.tools) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
            builder.resolve_refs(parameters);
            tool_rules.push_back(builder.add_schema(name + "-call", {
                {"type", "object"},
                {"properties", json {
                    {"name", json {{"const", name}}},
                    {"arguments", parameters},
                }},
                {"required", json::array({"name", "arguments"})},
            }));
        }

        auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
        if (params.tool_choice != "required") {
            data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
        }
    }, grammar_options);

    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
        try {
            std::regex start_pattern(R"([\n\s]*<tool_call>)");
            std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
            std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");

            auto end = input.end();
            std::sregex_iterator rend;
            std::sregex_iterator rit(input.begin(), end, start_pattern);
            if (rit == rend) {
                return {"assistant", input, {}};
            }

            common_chat_msg result;
            result.role = "assistant";
            result.content = rit->prefix();

            auto it = rit->suffix().first;
            while (it != end) {
                json call;
                if (!parse_json(it, end, call)) {
                    throw std::runtime_error("Failed to parse json tool call");
                }
                result.tool_calls.push_back({
                    call["name"],
                    call["arguments"].dump(),
                    /* id= */ "",
                });
                rit = {it, end, middle_pattern};
                if (rit != rend) {
                    it = rit->suffix().first;
                } else {
                    rit = {it, end, end_pattern};
                    if (rit == rend) {
                        throw std::runtime_error("Malformed input, missing </tool_call>");
                    }
                    break;
                }
            }
            return result;
        } catch (const std::exception &) {
            return {"assistant", input, {}};
        }
    });
    return data;
}
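Hermes 2 Pro interleaves optional prose with tagged calls (illustrative):

    Let me check that for you.
    <tool_call>
    {"name": "get_weather", "arguments": {"city": "Paris"}}
    </tool_call>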
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
    if (params.tools.is_null()) {
        common_chat_data data;
        data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
        data.handler = std::make_unique<text_chat_parser>();
        return data;
    }
    const auto & src = tmpl.source();

    if (src.find("<tool_call>") != std::string::npos) {
        return build_hermes_2_pro_tool_call_handler(tmpl, params);
    }
    if (src.find(">>>all") != std::string::npos) {
        return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
    }
    if (src.find("<|start_header_id|>") != std::string::npos
        && src.find("<function=") != std::string::npos) {
        return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
    }
    if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;

        // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
        // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
        // as it seems to be outputting some JSON.
        // TODO: make this conditional on a very small model (e.g. 1B / 3B).
        auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;

        return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
    }
    // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
    //     TODO: Command-R-Plus
    // }
    if (src.find("[TOOL_CALLS]") != std::string::npos) {
        return build_mistral_nemo_tool_call_handler(tmpl, params);
    }
    if (src.find(" functools[") != std::string::npos) {
        return build_firefunction_v2_tool_call_handler(tmpl, params);
    }
    return build_generic_tool_call_handler(tmpl, params);
}
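Taken together, a caller such as the server would use this roughly as follows (a sketch; tmpl and tools_json stand in for values the caller already has):

    common_chat_params params;
    params.messages = json::array({{{"role", "user"}, {"content", "What's the weather in Paris?"}}});
    params.tools = tools_json;          // OpenAI-style tool definitions
    params.tool_choice = "auto";
    params.parallel_tool_calls = false;

    common_chat_data chat = common_chat_init(tmpl, params);
    // chat.prompt is fed to the model; chat.grammar and chat.grammar_triggers
    // constrain sampling; chat.handler->parse_final(output) recovers the
    // assistant content and any tool calls.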

common/chat-handler.hpp (new file)

@@ -0,0 +1,43 @@
/*
    Copyright 2024 Google LLC

    Use of this source code is governed by an MIT-style
    license that can be found in the LICENSE file or at
    https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "common.h"
#include <json.hpp>
#include <memory>
#include <optional>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

struct common_chat_params {
    json messages;
    json tools;
    json tool_choice;
    json json_schema;
    bool parallel_tool_calls;
    bool stream;
};

class common_chat_parser {
public:
    virtual ~common_chat_parser() = default;

    virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
    virtual common_chat_msg parse_final(const std::string & input) = 0;
};

struct common_chat_data {
    std::string prompt;
    std::string grammar;
    std::vector<common_grammar_trigger> grammar_triggers;
    std::vector<std::string> additional_stops;
    std::unique_ptr<class common_chat_parser> handler;
};

struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

common/chat-template.hpp

@@ -20,7 +20,7 @@ namespace minja {
 class chat_template {
   public:

-  private:
+  // private:
     bool supports_tools_ = true;
     // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
     // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
@@ -28,6 +28,7 @@ class chat_template {
     bool requires_typed_content_ = false;
     bool supports_system_role_ = true;
     bool supports_parallel_tool_calls_ = false;
+    bool supports_code_interpreter_ = false;
     std::string source_;
     std::string bos_token_;
     std::string eos_token_;
@@ -60,28 +61,7 @@ class chat_template {
         });
         supports_tools_ = source.find("tools") != std::string::npos;

-        auto renders_string_arguments =
-            try_raw_render({
-                {
-                    {"role", "user"},
-                    {"content", "Hey"}
-                },
-                {
-                    {"role", "assistant"},
-                    {"tool_calls", json::array({
-                        {
-                            {"id", "call_1___"},
-                            {"type", "function"},
-                            {"function", {
-                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
-                                {"name", "ipython"},
-                            }},
-                        },
-                    })},
-                }
-            }, {}, false).find("{\"code\": \"print") != std::string::npos;
-        if (!renders_string_arguments) {
-            auto renders_object_arguments =
+        requires_object_arguments_ =
             try_raw_render({
                 {
                     {"role", "user"},
@@ -102,9 +82,27 @@ class chat_template {
                     },
                 })},
             }
-            }, {}, false).find("{\"code\": \"print") != std::string::npos;
-            requires_object_arguments_ = renders_object_arguments;
+            }, {}, false).find("{\"code\": \"print") != std::string::npos
+            && try_raw_render({
+                {
+                    {"role", "user"},
+                    {"content", "Hey"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"tool_calls", json::array({
+                        {
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
+                                {"name", "ipython"},
+                            }},
+                        },
+                    })},
+                }
+            }, {}, false).find("{\"code\": \"print") == std::string::npos;

         supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;

         supports_system_role_ = try_raw_render({
@@ -114,6 +112,8 @@ class chat_template {

         requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
             && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
+
+        supports_code_interpreter_ = source.find("code_interpreter") != std::string::npos;
     }

     const std::string & source() const { return source_; }
@@ -130,8 +130,45 @@ class chat_template {
                bool adjust_inputs = true) const
    {
        json actual_messages;
        json actual_tools;

        // First, "fix" messages so they have a chance to be rendered correctly by the template
        auto has_code_interpreter = false;
        for (const auto & tool : tools) {
            if (tool.contains("type") && tool.at("type") == "code_interpreter") {
                has_code_interpreter = true;
                break;
            }
        }

        if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
            actual_tools = json::array();
            for (const auto & tool : tools) {
                if (tool.contains("type") && tool.at("type") == "code_interpreter") {
                    static const auto python_tool = json::parse(R"({
                        "type": "function",
                        "function": {
                            "name": "ipython",
                            "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "code": {
                                        "type": "string",
                                        "description": "The code to run in the ipython interpreter."
                                    }
                                },
                                "required": ["code"]
                            }
                        }
                    })");
                    actual_tools.push_back(python_tool);
                } else {
                    actual_tools.push_back(tool);
                }
            }
        } else if (!tools.is_null()) {
            actual_tools = tools;
        }

        if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
            actual_messages = json::array();
@@ -173,7 +210,12 @@ class chat_template {
                        if (tool_call["type"] == "function") {
                            auto & function = tool_call.at("function");
                            std::string arguments = function.at("arguments");
                            try {
                                function["arguments"] = json::parse(arguments);
                            } catch (const std::exception & ecvt) {
                                fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
                                function["arguments"] = arguments;
                            }
                        }
                    }
                }
@@ -242,6 +284,9 @@ class chat_template {
        } else {
            actual_messages = messages;
        }
        // if (adjust_inputs) {
        //     fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
        // }

        auto context = minja::Context::make(json({
            {"messages", actual_messages},
@@ -251,8 +296,12 @@ class chat_template {
        }));

        if (!tools.is_null()) {
-            auto tools_val = minja::Value(tools);
+            auto tools_val = minja::Value(actual_tools);
            context->set("tools", tools_val);
+            if (has_code_interpreter) {
+                auto builtin_tools_val = minja::Value(json {"code_interpreter"});
+                context->set("builtin_tools", builtin_tools_val);
+            }
        }
        if (!extra_context.is_null()) {
            for (auto & kv : extra_context.items()) {

common/common.h

@@ -109,6 +109,11 @@ enum common_conversation_mode {
    COMMON_CONVERSATION_MODE_AUTO = 2,
};

+struct common_grammar_trigger {
+    std::string word;
+    bool at_start;
+};
+
// sampling parameters
struct common_params_sampling {
    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -155,7 +160,7 @@ struct common_params_sampling {
    };

    std::string grammar; // optional BNF-like grammar to constrain sampling
-   std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
+   std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar
    std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar

    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
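A trigger word lazily enables the grammar: sampling runs unconstrained until the word is generated (at_start restricts the match to the very beginning of the output). Wiring a handler's result into sampling looks roughly like this (a sketch against the fields above):

    common_params_sampling sparams;
    sparams.grammar = chat.grammar;
    for (const auto & trigger : chat.grammar_triggers) {
        sparams.grammar_trigger_words.push_back(trigger);
    }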

common/sampling.cpp

@@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
    std::vector<const char *> trigger_words;
    trigger_words.reserve(params.grammar_trigger_words.size());
    for (const auto & str : params.grammar_trigger_words) {
-       trigger_words.push_back(str.c_str());
+       trigger_words.push_back(str.word.c_str());
    }
    auto * result = new common_sampler {
        /* .params = */ params,

common/tool-call.cpp (deleted file)

@@ -1,747 +0,0 @@
#include "tool-call.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static json normalize_tools(const json & tools) {
    static const auto python_tool = json::parse(R"({
        "type": "function",
        "function": {
            "name": "python",
            "description": "Runs code in an Python interpreter and returns the result of the execution after 60 seconds.",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {
                        "type": "string",
                        "description": "The code to run in the Python interpreter."
                    }
                },
                "required": ["code"]
            }
        }
    })");

    auto results = json::array();
    for (const auto & tool : tools) {
        if (!tool.contains("type")) {
            continue;
        }
        if (tool["type"] == "code_interpreter") {
            results.push_back(python_tool);
        } else if (tool["type"] == "function") {
            results.push_back(tool);
        } else {
            continue;
        }
    }
    return results;
}
std::string common_tool_call_style_name(common_tool_call_style style) {
    switch (style) {
        case COMMON_TOOL_CALL_STYLE_NONE:
            return "None";
        case COMMON_TOOL_CALL_STYLE_GENERIC:
            return "Generic";
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
            return "Llama-3.1";
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
            return "Llama-3.2";
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
            return "FunctionaryV3Llama3";
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
            return "FunctionaryV3Llama3.1";
        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
            return "Hermes2Pro";
        case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS:
            return "CommandRPlus";
        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
            return "MistralNemo";
        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
            return "FirefunctionV2";
        default:
            return "Unknown";
    }
}
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) {
    const auto & src = chat_template.source();

    if (src.find("<tool_call>") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
    } else if (src.find(">>>all") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
    } else if (src.find("<|start_header_id|>") != std::string::npos
        && src.find("<function=") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
    } else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        if (src.find("<|python_tag|>") != std::string::npos) {
            return COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
        } else {
            return COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
        }
    } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
    } else if (src.find("[TOOL_CALLS]") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
    } else if (src.find(" functools[") != std::string::npos) {
        return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
    } else {
        return COMMON_TOOL_CALL_STYLE_GENERIC;
    }
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
    // https://json.nlohmann.me/features/parsing/sax_interface/
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;
        bool found_error;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
            this->position = position - 1;
            this->found_error = true;
            return false;
        }
        bool null() override { return true; }
        bool boolean(bool) override { return true; }
        bool number_integer(number_integer_t) override { return true; }
        bool number_unsigned(number_unsigned_t) override { return true; }
        bool number_float(number_float_t, const string_t &) override { return true; }
        bool string(string_t &) override { return true; }
        bool binary(binary_t &) override { return true; }
        bool start_object(std::size_t) override { return true; }
        bool key(string_t &) override { return true; }
        bool end_object() override { return true; }
        bool start_array(std::size_t) override { return true; }
        bool end_array() override { return true; }
    };
    json_error_locator err_loc;
    json::sax_parse(it, end, &err_loc);

    std::string::const_iterator temptative_end;
    if (err_loc.found_error) {
        temptative_end = it + err_loc.position;
    } else {
        temptative_end = end;
    }
    std::string json_sub {it, temptative_end};
    try {
        out = json::parse(json_sub);
        it = temptative_end;
        return true;
    } catch (const std::exception &) {
        return false;
    }
}
/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string & input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";
    auto end = input.end();
    auto it = input.begin();

    std::unordered_set<std::string> tool_names;
    if (check_names) {
        for (const auto & tool : tools) {
            if (!tool.contains("type")) {
                continue;
            }
            std::string type = tool.at("type");
            if (type == "function") {
                tool_names.insert(tool["function"]["name"]);
            } else if (type == "code_interpreter") {
                tool_names.insert("python");
            }
        }
    }

    while (it != end) {
        std::sregex_iterator rend;
        std::sregex_iterator rit(it, end, function_regex);
        if (rit == rend) {
            result.content += std::string(it, end);
            break;
        }
        auto name = rit->str(1);
        if (check_names && tool_names.find(name) == tool_names.end()) {
            result.content += std::string(it, rit->suffix().first);
            break;
        }

        result.content += std::string(it, rit->prefix().second);
        it = rit->suffix().first;

        json arguments;
        if (!parse_json(it, end, arguments)) {
            throw std::runtime_error("Failed to parse json tool call arguments");
        }
        if (!std::regex_search(it, end, match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern");
        }
        it = match.suffix().first;
        result.tool_calls.push_back({name, arguments.dump(), /* id= */ ""});
    }
    return result;
}
static common_chat_msg parse_hermes_tool_calls(const std::string & input) {
    try {
        std::regex start_pattern(R"([\n\s]*<tool_call>)");
        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");

        auto end = input.end();
        std::sregex_iterator rend;
        std::sregex_iterator rit(input.begin(), end, start_pattern);
        if (rit == rend) {
            return {"assistant", input, {}};
        }

        common_chat_msg result;
        result.role = "assistant";
        result.content = rit->prefix();

        auto it = rit->suffix().first;
        while (it != end) {
            json call;
            if (!parse_json(it, end, call)) {
                throw std::runtime_error("Failed to parse json tool call");
            }
            result.tool_calls.push_back({
                call["name"],
                call["arguments"].dump(),
                /* id= */ "",
            });
            rit = {it, end, middle_pattern};
            if (rit != rend) {
                it = rit->suffix().first;
            } else {
                rit = {it, end, end_pattern};
                if (rit == rend) {
                    throw std::runtime_error("Malformed input, missing </tool_call>");
                }
                break;
            }
        }
        return result;
    } catch (const std::exception & e) {
        return {"assistant", input, {}};
    }
}
static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string & input, bool allow_python_tag) {
    if (allow_python_tag) {
        static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
        std::smatch match;
        if (std::regex_search(input, match, python_tag_regex)) {
            return {
                /* .role = */ "assistant",
                /* .content = */ match.prefix().str(),
                /* .tool_calls = */ {
                    {
                        /* .name = */ "python",
                        /* .arguments = */ (json {{"code", match[1].str()}}).dump(),
                        /* .id = */ "",
                    },
                }
            };
        }
    }
    static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
    static std::regex close_regex("\\}");
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string & input) {
    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
    std::smatch match;
    if (std::regex_search(input, match, python_tag_regex)) {
        return {
            /* .role = */ "assistant",
            /* .content = */ match.prefix().str(),
            /* .tool_calls = */ {
                {
                    /* .name = */ "python",
                    /* .arguments = */ (json {{"code", match[1].str()}}).dump(),
                    /* .id = */ "",
                },
            }
        };
    }
    static std::regex function_regex(R"(<function=(\w+)>)");
    static std::regex close_regex(R"(</function>)");
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
}
static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string & input) {
    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
    static std::regex close_regex(R"($|(?=>>>))");
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_chat_msg parse_generic_tool_calls(const std::string & input) {
    json data = json::parse(input);
    common_chat_msg result;
    result.role = "assistant";
    if (data.contains("tool_calls")) {
        for (const auto & tool_call : data["tool_calls"]) {
            result.tool_calls.push_back({
                tool_call["name"],
                tool_call["arguments"].dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    } else if (data.contains("tool_call")) {
        result.tool_calls.push_back({
            data["tool_call"]["name"],
            data["tool_call"]["arguments"].dump(),
            /* id= */ "",
        });
    } else if (data.contains("response")) {
        const auto & response = data["response"];
        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
    }
    return result;
}
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string & input, const std::string & prefix, size_t rstrip_prefix = 0) {
    auto content_end = input.find(prefix);
    size_t tc_start = std::string::npos;

    common_chat_msg result;
    result.role = "assistant";
    const auto process_tool_calls = [&](const json & tool_calls) {
        for (const auto & tool_call : tool_calls) {
            const auto & arguments = tool_call["arguments"];
            result.tool_calls.push_back({
                tool_call["name"],
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    };
    if (content_end == std::string::npos) {
        result.content = input;
    } else {
        tc_start = content_end + prefix.size() - rstrip_prefix;
        result.content = input.substr(0, content_end);
        auto tool_calls = json::parse(input.substr(tc_start));
        process_tool_calls(tool_calls);
    }
    return result;
}
static common_chat_msg parse_mistral_nemo_tool_calls(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}

static common_chat_msg parse_firefunction_v2_tool_calls(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string & input) {
    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
    switch (style) {
        case COMMON_TOOL_CALL_STYLE_NONE:
            return {"assistant", input, {}};
        case COMMON_TOOL_CALL_STYLE_GENERIC:
            return parse_generic_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
            return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ true);
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
            return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ false);
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
            return parse_functionary_v3_tool_calls(tools, input);
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
            return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
            return parse_hermes_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
            return parse_mistral_nemo_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
            return parse_firefunction_v2_tool_calls(input);
        default:
            throw std::runtime_error("Unsupported tool call style");
    }
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
    json messages_with_system = messages;

    if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
        std::string existing_system = messages_with_system.at(0).at("content");
        messages_with_system[0] = json {
            {"role", "system"},
            {"content", existing_system + "\n" + system_prompt},
        };
    } else {
        messages_with_system.insert(messages_with_system.begin(), json {
            {"role", "system"},
            {"content", system_prompt},
        });
    }
    return messages_with_system;
}
common_tool_call_handler common_tool_call_handler_init(
|
||||
common_tool_call_style style,
|
||||
const common_chat_template & tmpl,
|
||||
bool allow_content,
|
||||
const nlohmann::ordered_json & parallel_tool_calls,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
const nlohmann::ordered_json & json_schema)
|
||||
{
|
||||
common_grammar_options grammar_options {
|
||||
/* .dotall = */ false,
|
||||
/* .compact_spaces = */ true,
|
||||
};
|
||||
common_tool_call_handler handler;
|
||||
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
|
||||
|
||||
switch (style) {
|
||||
case COMMON_TOOL_CALL_STYLE_NONE:
|
||||
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||
break;
|
||||
        case COMMON_TOOL_CALL_STYLE_GENERIC: {
            auto actual_tools = normalize_tools(tools);
            auto tool_call_schemas = json::array();
            for (const auto & tool : actual_tools) {
                const auto & function = tool["function"];
                auto tool_schema = json {
                    {"type", "object"},
                    {"properties", {
                        {"name", {
                            {"type", "string"},
                            {"const", function["name"]},
                        }},
                        {"arguments", function["parameters"]},
                    }},
                    {"required", json::array({"name", "arguments"})},
                };
                if (function.contains("description")) {
                    tool_schema["description"] = function["description"];
                }
                if (parallel) {
                    tool_schema["properties"]["id"] = {
                        {"type", "string"},
                        {"minLength", 4},
                    };
                    tool_schema["required"].push_back("id");
                }
                tool_call_schemas.emplace_back(tool_schema);
            }
            const auto tool_call =
                parallel
                    ? json {
                        {"type", "object"},
                        {"properties", {
                            {"tool_calls", {
                                {"type", "array"},
                                {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                                    {"anyOf", tool_call_schemas},
                                }},
                                {"minItems", 1},
                            }},
                        }},
                        {"required", json::array({"tool_calls"})},
                    }
                    : json {
                        {"type", "object"},
                        {"properties", {
                            {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                                {"anyOf", tool_call_schemas},
                            }},
                        }},
                        {"required", json::array({"tool_call"})},
                    };
            const auto schema =
                allow_content
                    ? json {
                        {"anyOf", json::array({
                            tool_call,
                            {
                                {"type", "object"},
                                {"properties", {
                                    {"response", json_schema.is_null()
                                        ? json {{"type", "string"}}
                                        : json_schema
                                    },
                                }},
                                {"required", json::array({"response"})},
                            },
                        })}
                    }
                    : tool_call;
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                builder.add_schema("root", schema);
            }, grammar_options);
            // TODO: add schema to system prompt.
            auto tweaked_messages = add_system(
                messages,
                "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
            handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
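            // Illustrative sketch (not from the source; tool name and arguments are assumed):
            // with a single "special_function" tool and parallel calls off, the schema above
            // constrains the model to emit either
            //   {"tool_call": {"name": "special_function", "arguments": {"arg1": 1}}}
            // or, when allow_content is true, a plain answer wrapped as
            //   {"response": "..."}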
        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                auto schemas = json::array();
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    schemas.push_back({
                        {"type", "object"},
                        {"properties", {
                            // Important note: the model is probably trained to take a JSON stringified arguments value.
                            // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
                            {"name", {
                                {"type", "string"},
                                {"const", function["name"]},
                            }},
                            {"arguments", function["parameters"]},
                            {"id", {
                                {"type", "string"},
                                // Nemo's template expects a 9-character alphanumeric ID.
                                {"pattern", "^[a-zA-Z0-9]{9}$"},
                            }},
                        }},
                        {"required", json::array({"name", "arguments", "id"})},
                    });
                }
                auto schema = json {
                    {"type", "array"},
                    {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                    {"minItems", 1},
                };
                if (!parallel) {
                    schema["maxItems"] = 1;
                }
                builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
            }, grammar_options);
            if (allow_content) {
                handler.grammar_triggers.push_back("[TOOL_CALLS]");
            }
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
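            // Illustrative sketch (tool name and arguments are assumed): the root rule above
            // accepts output such as
            //   [TOOL_CALLS][{"name": "special_function", "arguments": {"arg1": 1}, "id": "call12345"}]
            // where the id must match ^[a-zA-Z0-9]{9}$ ("call12345" is 9 characters).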
        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                auto schemas = json::array();
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    schemas.push_back({
                        {"type", "object"},
                        {"properties", {
                            {"name", {
                                {"type", "string"},
                                {"const", function["name"]},
                            }},
                            {"arguments", function["parameters"]},
                        }},
                        {"required", json::array({"name", "arguments", "id"})},
                    });
                }
                auto schema = json {
                    {"type", "array"},
                    {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                    {"minItems", 1},
                };
                if (!parallel) {
                    schema["maxItems"] = 1;
                }
                builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
            }, grammar_options);
            if (allow_content) {
                handler.grammar_triggers.push_back(" functools[");
            }
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
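            // Illustrative sketch (tool name and arguments are assumed): the root rule above
            // accepts output such as
            //    functools[{"name": "special_function", "arguments": {"arg1": 1}}]
            // which parse_firefunction_v2_tool_calls() above turns back into a tool-call array.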
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
            auto builtin_tools = json {"wolfram_alpha", "brave_search"};
            for (const auto & tool : tools) {
                if (!tool.contains("type")) {
                    continue;
                }
                if (tool["type"] == "code_interpreter") {
                    builtin_tools.push_back("code_interpreter");
                    break;
                }
            }
            auto actual_tools = normalize_tools(tools);

            auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;

            // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
            // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
            // as it seems to be outputting some JSON.
            // TODO: make this conditional on a very small model (e.g. 1B / 3B).
            auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;

            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;

                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    builder.resolve_refs(parameters);
                    if (uses_python_tag && (name == "ipython" || builtin_tools.contains(name))) {
                        tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
                        if (allow_content) {
                            handler.grammar_triggers.push_back("<|python_tag|>");
                        }
                    } else {
                        //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
                        tool_rules.push_back(
                            builder.add_rule(
                                name + "-call",
                                "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                                    builder.add_schema(name + "-args", parameters) +
                                " \"}\""));
                        if (allow_content && !eagerly_match_any_json) {
                            handler.grammar_triggers.push_back("{\"name\": \"" + name + "\"");
                            // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
                            // Note that c++11's regex doesn't support partial matches, otherwise it would make
                            // sense to add support for trigger regexes to the antiprompt mechanism.
                            handler.grammar_triggers.push_back("{\n\t\"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\n  \"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\n    \"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\"type\": \"function\", \"name\": \"" + name + "\"");
                        }
                    }
                }

                if (allow_content && eagerly_match_any_json) {
                    handler.grammar_triggers.push_back("{\"");
                    handler.grammar_triggers.push_back("{\n\t\"");
                    handler.grammar_triggers.push_back("{\n  \"");
                    handler.grammar_triggers.push_back("{\n    \"");
                }

                builder.add_rule("root", string_join(tool_rules, " | "));
            }, grammar_options);
            handler.additional_stops.push_back("<|eom_id|>");
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
                {"builtin_tools", builtin_tools},
            });
            break;
        }
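            // Illustrative sketch (tool name and arguments are assumed): the grammar above
            // matches calls such as
            //   {"name": "special_function", "parameters": {"arg1": 1}}
            // (optionally prefixed with "type": "function") and, for Llama 3.1 builtin
            // tools, raw code after the tag:
            //   <|python_tag|>print("hey")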
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
            // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
            // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> first_tool_rules;
                std::vector<std::string> subsequent_tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    auto args_rule = builder.add_schema(name + "-args", parameters);
                    first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                    subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
                    if (allow_content) {
                        handler.grammar_triggers.push_back(name + "\n");
                        handler.grammar_triggers.push_back("\n>>>" + name + "\n");
                    }
                }
                auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
                if (parallel) {
                    auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
                    builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
                } else {
                    builder.add_rule("root", first_rule);
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            // handler.parser = parse_functionary_3_2_tool_calls;
            break;
        }
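            // Illustrative sketch (tool names and arguments are assumed): with parallel
            // calls enabled the root rule expands to first_tool_call (subsequent_tool_call)*,
            // presumably continuing after a ">>>" already emitted by the template, e.g.
            //   special_function
            //   {"arg1": 1}
            //   >>>special_function
            //   {"arg1": 2}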
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
            // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
            // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
            // TODO: handle tool {type: code_interpreter} as python
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    if (name == "python" || name == "ipython") {
                        tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
                        if (allow_content) {
                            handler.grammar_triggers.push_back("<|python_tag|>");
                        }
                    } else {
                        tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
                    }
                }
                auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
                if (allow_content) {
                    handler.grammar_triggers.push_back("<function=");
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            // handler.parser = parse_functionary_3_2_tool_calls;
            break;
        }
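            // Illustrative sketch (tool name and arguments are assumed): the rules above
            // match output such as
            //   <function=special_function>{"arg1": 1}</function>
            // or, for the python/ipython pseudo-tools, free-form code after the tag:
            //   <|python_tag|>print("hey")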
        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
            // NousResearch Hermes 2 Pro format:
            // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    builder.resolve_refs(parameters);
                    tool_rules.push_back(builder.add_schema(name + "-call", {
                        {"type", "object"},
                        {"properties", json {
                            {"name", json {{"const", name}}},
                            {"arguments", parameters},
                        }},
                        {"required", json::array({"name", "arguments"})},
                    }));
                }

                auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
                if (allow_content) {
                    handler.grammar_triggers.push_back("<tool_call>");
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
        default:
            throw std::runtime_error("Unsupported tool call style");
    }
    return handler;
}

common/tool-call.h
@@ -1,44 +0,0 @@
#pragma once

#include "ggml.h"
#include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

enum common_tool_call_style {
    COMMON_TOOL_CALL_STYLE_UNKNOWN,
    COMMON_TOOL_CALL_STYLE_NONE,
    COMMON_TOOL_CALL_STYLE_GENERIC,
    COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
    COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
    COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
    COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
    COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
    COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
};

struct common_tool_call_handler {
    std::string prompt;
    std::string grammar;
    std::vector<std::string> grammar_triggers;
    std::vector<std::string> additional_stops;
};

std::string common_tool_call_style_name(common_tool_call_style style);

common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);

common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

common_tool_call_handler common_tool_call_handler_init(
    common_tool_call_style style,
    const common_chat_template & tmpl,
    bool allow_content,
    const nlohmann::ordered_json & parallel_tool_calls,
    const nlohmann::ordered_json & messages,
    const nlohmann::ordered_json & tools,
    const nlohmann::ordered_json & json_schema = {});
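// Sketch of the replacement API (assumed; inferred from the call sites in the server
// changes below, not copied from the new common/chat-handler.hpp):
//
//   struct common_chat_params {
//       nlohmann::ordered_json messages;
//       nlohmann::ordered_json tools;
//       std::string            tool_choice = "auto";
//       nlohmann::ordered_json json_schema;
//       bool                   parallel_tool_calls = false;
//       bool                   stream = false;
//   };
//
//   struct common_chat_data {
//       std::string                         grammar;
//       std::vector<common_grammar_trigger> grammar_triggers; // each exposes a .word
//       std::vector<std::string>            additional_stops;
//       std::unique_ptr<common_chat_parser> handler;          // parse_partial() / parse_final()
//   };
//
//   common_chat_data common_chat_init(const common_chat_template & tmpl, const common_chat_params & params);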
examples/server/server.cpp
@@ -117,8 +117,7 @@ struct slot_params {
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    json oaicompat_tools;
    common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
    std::shared_ptr<common_chat_parser> chat_parser;

    json to_json() const {
        std::vector<std::string> samplers;

@@ -166,7 +165,7 @@ struct slot_params {
            {"n_probs", sampling.n_probs},
            {"min_keep", sampling.min_keep},
            {"grammar", sampling.grammar},
            {"grammar_trigger_words", sampling.grammar_trigger_words},
            // {"grammar_trigger_words", sampling.grammar_trigger_words},
            {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
            {"samplers", samplers},
            {"speculative.n_max", speculative.n_max},
@@ -212,6 +211,7 @@ struct server_task {
    static slot_params params_from_json_cmpl(
            const llama_context * ctx,
            const common_params & params_base,
            const common_chat_template * tmpl,
            const json & data) {
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -322,6 +322,41 @@
            }
        }

        {
            params.antiprompt.clear();
            const auto stop = data.find("stop");
            if (stop != data.end()) {
                params.antiprompt = *stop;
            }
        }

        if (tmpl && params_base.use_jinja) {
            common_chat_params chat_params;
            chat_params.messages = json_value(data, "messages", json::array());
            chat_params.tools = json_value(data, "tools", json());
            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
            chat_params.json_schema = json_value(data, "json_schema", json());
            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
            chat_params.stream = json_value(data, "stream", false);

            auto chat_data = common_chat_init(*tmpl, chat_params);
            params.chat_parser = std::move(chat_data.handler);
            params.sampling.grammar = chat_data.grammar;
            for (const auto & stop : chat_data.additional_stops) {
                params.antiprompt.push_back(stop);
            }
            for (const auto & trigger : chat_data.grammar_triggers) {
                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                if (ids.size() == 1) {
                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                    continue;
                }
                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
                params.sampling.grammar_trigger_words.push_back(trigger);
            }
        }

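        // Illustrative sketch (hypothetical request body): the fields consumed above map
        // onto an OpenAI-style chat completion payload such as
        //   {
        //     "messages": [{"role": "user", "content": "What is the weather?"}],
        //     "tools": [ ... ], "tool_choice": "auto",
        //     "parallel_tool_calls": false, "stream": false
        //   }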
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
|
@ -336,13 +371,7 @@ struct server_task {
|
|||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
}
|
||||
|
||||
if (data.contains("tools")) {
|
||||
params.oaicompat_tools = data.at("tools");
|
||||
}
|
||||
if (data.contains("tool_call_style")) {
|
||||
params.oaicompat_tool_call_style = data.at("tool_call_style");
|
||||
}
|
||||
LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
@@ -379,45 +408,11 @@ struct server_task {
            }
        }

        auto to_string_vec = [](const json & j) {
            std::vector<std::string> out;
            if (j.is_array()) {
                for (const auto & e : j) {
                    if (e.is_string()) {
                        out.push_back(e);
                    }
                }
            }
            return out;
        };

        {
            params.antiprompt.clear();
            const auto stop = data.find("stop");
            if (stop != data.end()) {
                params.antiprompt = to_string_vec(*stop);
            }
        }

        {
            const auto grammar_trigger_words = data.find("grammar_trigger_words");
            if (grammar_trigger_words != data.end()) {
                for (const auto & word : to_string_vec(*grammar_trigger_words)) {
                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                        continue;
                    }
                    params.sampling.grammar_trigger_words.push_back(word);
                }
            }
        }

        {
            const auto samplers = data.find("samplers");
            if (samplers != data.end()) {
                if (samplers->is_array()) {
                    params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
                    params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
                } else if (samplers->is_string()){
                    params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
                }
@@ -592,8 +587,7 @@ struct server_task_result_cmpl_final : server_task_result {
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    json oaicompat_tools;
    common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
    common_chat_msg oaicompat_chat_msg;

    virtual int get_index() override {
        return index;
@@ -688,18 +682,13 @@ struct server_task_result_cmpl_final : server_task_result {
    json to_json_oaicompat_chat() {
        std::string finish_reason = "length";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
            finish_reason = "stop";
            finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
        }

        json tool_calls;
        json message_content;
        if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
            auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
            if (!parsed_tool_calls.tool_calls.empty()) {
                finish_reason = "tool_calls";
                message_content = parsed_tool_calls.content;
        if (!oaicompat_chat_msg.tool_calls.empty()) {
            tool_calls = json::array();
            for (const auto & tc : parsed_tool_calls.tool_calls) {
            for (const auto & tc : oaicompat_chat_msg.tool_calls) {
                tool_calls.push_back({
                    {"type", "function"},
                    {"function", {

@@ -709,18 +698,13 @@ struct server_task_result_cmpl_final : server_task_result {
                    {"id", tc.id.empty() ? json() : json(tc.id)},
                });
            }
        } else {
            message_content = parsed_tool_calls.content;
            }
        } else {
            message_content = content;
        }

        json choice {
            {"finish_reason", finish_reason},
            {"index", 0},
            {"message", json {
                {"content", message_content},
                {"content", oaicompat_chat_msg.content},
                {"tool_calls", tool_calls},
                {"role", "assistant"},
            }},
@@ -812,6 +796,8 @@ struct server_task_result_cmpl_partial : server_task_result {
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_chat_msg;
    std::shared_ptr<common_chat_parser> chat_parser;

    virtual int get_index() override {
        return index;
@@ -1234,6 +1220,8 @@ struct server_slot {

    std::string stopping_word;

    std::shared_ptr<common_chat_parser> chat_parser;

    // sampling
    json json_schema;

@@ -2260,6 +2248,10 @@ struct server_context {
    }

    void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
        if (!opt_msg) {
            return;
        }
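        // Note (inferred; not stated in the source): parse_partial() presumably buffers
        // text while a tool call is still being generated, returning std::nullopt until
        // there is a complete delta worth streaming to the client.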
        auto res = std::make_unique<server_task_result_cmpl_partial>();

        res->id = slot.id_task;

@@ -2275,6 +2267,7 @@ struct server_context {
        res->oaicompat = slot.params.oaicompat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
        res->oaicompat_chat_msg = *opt_msg;

        // populate res.probs_output
        if (slot.params.sampling.n_probs > 0) {
@@ -2315,8 +2308,7 @@ struct server_context {
        res->oaicompat = slot.params.oaicompat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
        res->oaicompat_tools = slot.params.oaicompat_tools;
        res->oaicompat_tool_call_style = slot.params.oaicompat_tool_call_style;
        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);

        // populate res.probs_output
        if (slot.params.sampling.n_probs > 0) {
@@ -3776,7 +3768,7 @@ int main(int argc, char ** argv) {
            std::function<bool()> is_connection_closed,
            httplib::Response & res,
            oaicompat_type oaicompat,
            common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) {
            const common_chat_template * tmpl) {
        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

        if (ctx_server.params_base.embedding) {
@@ -3788,6 +3780,20 @@ int main(int argc, char ** argv) {
        std::vector<server_task> tasks;

        try {
            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
            std::string prompt;
            if (tmpl && ctx_server.params_base.use_jinja) {
                auto chat_data = common_chat_init(*tmpl, {
                    /* .messages = */ json_value(data, "messages", json::array()),
                    /* .tools = */ json_value(data, "tools", json()),
                    /* ... */
                });

                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
            } else {
                prompt = data.at("prompt").get<std::string>();
            }
            task.params.chat_parser = common_chat_init()
            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
            tasks.reserve(tokenized_prompts.size());
            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -3800,12 +3806,14 @@ int main(int argc, char ** argv) {
                task.params = server_task::params_from_json_cmpl(
                    ctx_server.ctx,
                    ctx_server.params_base,
                    nullptr,
                    data);
                task.id_selected_slot = json_value(data, "id_slot", -1);

                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.oaicompat_cmpl_id = completion_id;
                task.params.chat_parser = common_chat_init()
                task.params.oaicompat_tools = json_value(data, "tools", json());
                task.params.oaicompat_tool_call_style = tool_call_style;
                // oaicompat_model is already populated by params_from_json_cmpl
@@ -3983,18 +3991,16 @@ int main(int argc, char ** argv) {

        auto body = json::parse(req.body);
        const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
        auto tool_call_style = common_tool_call_style_detect(chat_template);
        LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str());
        LOG_INF("Request: %s\n", body.dump(2).c_str());

        json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja);
        json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);

        return handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            data,
            req.is_connection_closed,
            res,
            OAICOMPAT_TYPE_CHAT,
            tool_call_style);
            OAICOMPAT_TYPE_CHAT);
    };

    const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {

examples/server/utils.hpp
@@ -17,8 +17,8 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat-handler.hpp"
#include "chat-template.hpp"
#include "tool-call.h"

#include <random>
#include <sstream>
@@ -581,24 +581,18 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse(
    const json & body, /* openai api json semantics */
    const common_chat_template & tmpl,
    common_tool_call_style tool_call_style,
    bool use_jinja)
{
    json llama_params;

    auto tools = json_value(body, "tools", json());
    auto has_tools = tools.is_array() && !tools.empty();
    auto stream = json_value(body, "stream", false);

    if (has_tools) {
    if (tools.is_array() && !tools.empty()) {
        if (stream) {
            throw std::runtime_error("Cannot use tools with stream");
        }
        if (use_jinja) {
            if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
                throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
            }
        } else {
        if (!use_jinja) {
            throw std::runtime_error("tools param requires --jinja flag");
        }
    }
@@ -627,32 +621,16 @@ static json oaicompat_completion_params_parse(

    // Apply chat template to the list of messages
    if (use_jinja) {
        bool allow_content = tool_choice != "required";
        if (tool_choice != "none" && has_tools) {
            llama_params["tools"] = tools;
            llama_params["tool_call_style"] = tool_call_style;

            auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json();
            llama_params["parallel_tool_calls"] = parallel_tool_calls;

            auto handler = common_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
            llama_params["prompt"] = handler.prompt;

            for (const auto & stop : handler.additional_stops) {
                llama_params["stop"].push_back(stop);
        auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
            throw std::runtime_error("Invalid tool_choice: " + tool_choice);
        }
            if (!handler.grammar_triggers.empty()) {
                llama_params["grammar_trigger_words"] = handler.grammar_triggers;
            }
            if (!handler.grammar.empty()) {
                if (llama_params.contains("grammar")) {
        llama_params["tool_choice"] = tool_choice;
        llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
        if (tool_choice != "none" && llama_params.contains("grammar")) {
            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
        }
                llama_params["grammar"] = handler.grammar;
            }
        } else {
            llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
        }
    } else {
        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
    }

tests/CMakeLists.txt
@@ -133,7 +133,7 @@ llama_target_and_test(test-chat-template.cpp)
# llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-gguf.cpp)
llama_target_and_test(test-backend-ops.cpp)
llama_target_and_test(test-tool-call.cpp)
llama_target_and_test(test-chat-handler.cpp)

llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
llama_target_and_test(test-autorelease.cpp LABEL "model")

tests/test-chat-handler.cpp
@@ -1,4 +1,4 @@
#include "tool-call.h"
#include "chat-handler.hpp"
#include "llama-grammar.h"
#include "unicode.h"

@@ -20,6 +20,7 @@ static void assert_equals(const T & expected, const T & actual) {
}

static std::string read_file(const std::string &path) {
    std::cout << "# Reading: " << path << std::endl << std::flush;
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -76,10 +77,16 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
    assert_equals(expected_content, result.content);
    auto tool_calls = json::array();
    for (const auto & tc : result.tool_calls) {
        auto arguments = tc.arguments;
        try {
            arguments = dump(json::parse(arguments));
        } catch (const std::exception & e) {
            // ignore
        }
        auto tool_call = json {
            {"type", "function"},
            {"function", {
                {"arguments", dump(json::parse(tc.arguments))},
                {"arguments", arguments},
                {"name", tc.name},
            }},
        };
@@ -94,8 +101,7 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
    assert_equals(expected, actual);
}

const json tools = json::parse(R"([
  {
const auto special_function_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "special_function",

@@ -111,25 +117,28 @@ const json tools = json::parse(R"([
      "required": ["arg1"]
    }
  }
},
{
})");
const auto python_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "python",
    "description": "a python interpreter",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "The code."
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
}
])");
})");
const auto code_interpreter_tool = json::parse(R"({
  "type": "code_interpreter"
})");
const json tools = {special_function_tool, code_interpreter_tool};

static void test_parsing() {
    json request = {
@@ -226,9 +235,7 @@ static void test_parsing() {
            {"type", "function"},
            {"function", {
                {"name", "python"},
                {"arguments", dump({
                    {"code", "this could be anything"}
                })}
                {"arguments", "this could be anything"},
            }}
        }});
    test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,

@@ -238,7 +245,7 @@ static void test_parsing() {
            {"type", "function"},
            {"function", {
                {"name", "python"},
                {"arguments", dump({{"code", ""}})}
                {"arguments", ""},
            }}
        }});
    auto special_function_call = json {
@@ -332,6 +339,8 @@ static void test_tool_call_style_detection() {
}

static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
    fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
    fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
    auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
    auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());

@@ -354,9 +363,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
    return delta;
}

static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
    std::cout << "# Testing template: " << template_file << std::endl << std::flush;
    const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
    auto tool_call_style = common_tool_call_style_detect(tmpl);
    auto & tool_calls = tool_calling_message.at("tool_calls");

@@ -404,24 +411,77 @@ static void test_grammars() {
    auto tool_call_message_with_id = json::parse(tool_call_message.dump());
    tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";

    test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
        /* skip_grammar_test= */ true);
    test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "<s>", "</s>", { "<|eot_id|>" }, tool_call_message, tools);
    test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "<s>", "</s>", { "<end_of_turn>" }, tool_call_message_with_id, tools);
    test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "<s>", "</s>", { "<|end|>" }, tool_call_message_with_id, tools);
    auto python_tool_call_message = json {
        {"role", "assistant"},
        {"content", ""},
        {"tool_calls", json {{
            {"type", "function"},
            {"function", {
                {"name", "python"},
                {"arguments", "print('hey')"}
            }},
        }}}
    };

    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
    //     assert_equals(tmpl.requires_object_arguments_, true);
    //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
    //     test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    // }
    // {
    //     const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
    // }
    {
        const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
        test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
    }
    {
        const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
    }
    {
        const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
        test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
    }
}

int main() {
    test_tool_call_style_detection();
    test_parsing();
    // test_tool_call_style_detection();
    // test_parsing();
    test_grammars();

    std::cout << "\n[tool-call] All tests passed!" << std::endl;