WIP chat handlers — commit 36ed106f84 (parent 46415d7a51): 13 changed files with 1071 additions and 999 deletions, including a 3-line Makefile change.
Makefile
3
Makefile
|
@ -1363,10 +1363,11 @@ llama-server: \
|
||||||
examples/server/httplib.h \
|
examples/server/httplib.h \
|
||||||
examples/server/index.html.hpp \
|
examples/server/index.html.hpp \
|
||||||
examples/server/loading.html.hpp \
|
examples/server/loading.html.hpp \
|
||||||
|
common/chat-handler.cpp \
|
||||||
|
common/chat-handler.hpp \
|
||||||
common/chat-template.hpp \
|
common/chat-template.hpp \
|
||||||
common/json.hpp \
|
common/json.hpp \
|
||||||
common/minja.hpp \
|
common/minja.hpp \
|
||||||
common/tool-call.h \
|
|
||||||
$(OBJ_ALL)
|
$(OBJ_ALL)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||||
|
|
|
@ -56,6 +56,8 @@ add_library(${TARGET} STATIC
|
||||||
arg.cpp
|
arg.cpp
|
||||||
arg.h
|
arg.h
|
||||||
base64.hpp
|
base64.hpp
|
||||||
|
chat-handler.cpp
|
||||||
|
chat-handler.hpp
|
||||||
chat-template.hpp
|
chat-template.hpp
|
||||||
common.cpp
|
common.cpp
|
||||||
common.h
|
common.h
|
||||||
|
@ -72,7 +74,6 @@ add_library(${TARGET} STATIC
|
||||||
sampling.h
|
sampling.h
|
||||||
speculative.cpp
|
speculative.cpp
|
||||||
speculative.h
|
speculative.h
|
||||||
tool-call.cpp
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (BUILD_SHARED_LIBS)
|
if (BUILD_SHARED_LIBS)
|
||||||
|
|
720
common/chat-handler.cpp
Normal file
720
common/chat-handler.cpp
Normal file
|
@ -0,0 +1,720 @@
|
||||||
#include "chat-handler.hpp"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h"
#include "minja.hpp"

#include <algorithm>
|
|
||||||
|
const common_grammar_options grammar_options {
|
||||||
|
/* .dotall = */ false,
|
||||||
|
/* .compact_spaces = */ false,
|
||||||
|
// /* .compact_spaces = */ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
||||||
|
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||||
|
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||||
|
std::size_t position;
|
||||||
|
bool found_error;
|
||||||
|
|
||||||
|
json_error_locator() : position(0), found_error(false) {}
|
||||||
|
|
||||||
|
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
|
||||||
|
this->position = position - 1;
|
||||||
|
this->found_error = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool null() override { return true; }
|
||||||
|
bool boolean(bool) override { return true; }
|
||||||
|
bool number_integer(number_integer_t) override { return true; }
|
||||||
|
bool number_unsigned(number_unsigned_t) override { return true; }
|
||||||
|
bool number_float(number_float_t, const string_t &) override { return true; }
|
||||||
|
bool string(string_t &) override { return true; }
|
||||||
|
bool binary(binary_t &) override { return true; }
|
||||||
|
bool start_object(std::size_t) override { return true; }
|
||||||
|
bool key(string_t &) override { return true; }
|
||||||
|
bool end_object() override { return true; }
|
||||||
|
bool start_array(std::size_t) override { return true; }
|
||||||
|
bool end_array() override { return true; }
|
||||||
|
};
|
||||||
|
json_error_locator err_loc;
|
||||||
|
json::sax_parse(it, end, &err_loc);
|
||||||
|
|
||||||
|
std::string::const_iterator temptative_end;
|
||||||
|
if (err_loc.found_error) {
|
||||||
|
temptative_end = it + err_loc.position;
|
||||||
|
} else {
|
||||||
|
temptative_end = end;
|
||||||
|
}
|
||||||
|
std::string json_sub {it, temptative_end};
|
||||||
|
try {
|
||||||
|
out = json::parse(json_sub);
|
||||||
|
it = temptative_end;
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception &) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||||
|
* Aggregates the prefix, suffix and in-between text into the content.
|
||||||
|
*/
|
||||||
|
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
|
||||||
|
std::smatch match;
|
||||||
|
|
||||||
|
common_chat_msg result;
|
||||||
|
result.role = "assistant";
|
||||||
|
auto end = input.end();
|
||||||
|
auto it = input.begin();
|
||||||
|
|
||||||
|
std::vector<std::string> tool_names;
|
||||||
|
if (check_names) {
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string type = tool.at("type");
|
||||||
|
if (type == "function") {
|
||||||
|
tool_names.push_back(tool["function"]["name"]);
|
||||||
|
} else if (type == "code_interpreter") {
|
||||||
|
tool_names.push_back("ipython");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (it != end) {
|
||||||
|
std::sregex_iterator rend;
|
||||||
|
std::sregex_iterator rit(it, end, function_regex);
|
||||||
|
if (rit == rend) {
|
||||||
|
fprintf(stderr, "No more tool calls found\n");
|
||||||
|
result.content += std::string(it, end);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto name = rit->str(1);
|
||||||
|
if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
|
||||||
|
fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
|
||||||
|
result.content += std::string(it, rit->suffix().first);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.content += std::string(it, rit->prefix().second);
|
||||||
|
it = rit->suffix().first;
|
||||||
|
|
||||||
|
|
||||||
|
json arguments;
|
||||||
|
if (!parse_json(it, end, arguments)) {
|
||||||
|
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||||
|
}
|
||||||
|
if (!std::regex_search(it, end, match, close_regex)) {
|
||||||
|
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||||
|
}
|
||||||
|
it = match.suffix().first;
|
||||||
|
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
||||||
|
auto content_end = input.find(prefix);
|
||||||
|
size_t tc_start = std::string::npos;
|
||||||
|
|
||||||
|
common_chat_msg result;
|
||||||
|
result.role = "assistant";
|
||||||
|
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||||
|
for (const auto & tool_call : tool_calls) {
|
||||||
|
const auto & arguments = tool_call["arguments"];
|
||||||
|
result.tool_calls.push_back({
|
||||||
|
tool_call["name"],
|
||||||
|
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||||
|
tool_call.contains("id") ? tool_call["id"] : "",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if (content_end == std::string::npos) {
|
||||||
|
result.content = input;
|
||||||
|
} else {
|
||||||
|
tc_start = content_end + prefix.size() - rstrip_prefix;
|
||||||
|
result.content = input.substr(0, content_end);
|
||||||
|
auto tool_calls = json::parse(input.substr(tc_start));
|
||||||
|
process_tool_calls(tool_calls);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
|
||||||
|
json messages_with_system = messages;
|
||||||
|
|
||||||
|
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
|
||||||
|
std::string existing_system = messages_with_system.at(0).at("content");
|
||||||
|
messages_with_system[0] = json {
|
||||||
|
{"role", "system"},
|
||||||
|
{"content", existing_system + "\n" + system_prompt},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
messages_with_system.insert(messages_with_system.begin(), json {
|
||||||
|
{"role", "system"},
|
||||||
|
{"content", system_prompt},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return messages_with_system;
|
||||||
|
}
|
||||||
|
|
||||||
|
class text_chat_parser : public common_chat_parser {
|
||||||
|
public:
|
||||||
|
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
|
||||||
|
return parse_final(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_msg parse_final(const std::string & input) override {
|
||||||
|
return {
|
||||||
|
/* .role = */ "assistant",
|
||||||
|
/* .content = */ input,
|
||||||
|
/* .tool_calls = */ {},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class monolithic_chat_parser : public common_chat_parser {
|
||||||
|
|
||||||
|
std::string input_buffer_;
|
||||||
|
std::function<common_chat_msg(const std::string & input)> parse_final_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
|
||||||
|
|
||||||
|
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
|
||||||
|
input_buffer_ += input;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_msg parse_final(const std::string & input) override {
|
||||||
|
input_buffer_ += input;
|
||||||
|
auto out = parse_final_(input_buffer_);
|
||||||
|
input_buffer_.clear();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
common_chat_data data;
|
||||||
|
|
||||||
|
auto tool_call_schemas = json::array();
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
auto tool_schema = json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", function["name"]},
|
||||||
|
}},
|
||||||
|
{"arguments", function["parameters"]},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments"})},
|
||||||
|
};
|
||||||
|
if (function.contains("description")) {
|
||||||
|
tool_schema["description"] = function["description"];
|
||||||
|
}
|
||||||
|
if (params.parallel_tool_calls) {
|
||||||
|
tool_schema["properties"]["id"] = {
|
||||||
|
{"type", "string"},
|
||||||
|
{"minLength", 4},
|
||||||
|
};
|
||||||
|
tool_schema["required"].push_back("id");
|
||||||
|
}
|
||||||
|
tool_call_schemas.emplace_back(tool_schema);
|
||||||
|
}
|
||||||
|
const auto tool_call =
|
||||||
|
params.parallel_tool_calls
|
||||||
|
? json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"tool_calls", {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
||||||
|
{"anyOf", tool_call_schemas},
|
||||||
|
}},
|
||||||
|
{"minItems", 1},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"tool_calls"})},
|
||||||
|
}
|
||||||
|
: json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
||||||
|
{"anyOf", tool_call_schemas},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"tool_call"})},
|
||||||
|
};
|
||||||
|
const auto schema =
|
||||||
|
params.tool_choice != "required"
|
||||||
|
? json {
|
||||||
|
{"anyOf", json::array({
|
||||||
|
tool_call,
|
||||||
|
{
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"response", params.json_schema.is_null()
|
||||||
|
? json {{"type", "string"}}
|
||||||
|
: params.json_schema
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"response"})},
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
: tool_call;
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
builder.add_schema("root", schema);
|
||||||
|
}, grammar_options);
|
||||||
|
|
||||||
|
// TODO: add schema to system prompt.
|
||||||
|
auto tweaked_messages = add_system(
|
||||||
|
params.messages,
|
||||||
|
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
||||||
|
|
||||||
|
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
|
||||||
|
json data = json::parse(input);
|
||||||
|
common_chat_msg result;
|
||||||
|
result.role = "assistant";
|
||||||
|
if (data.contains("tool_calls")) {
|
||||||
|
for (const auto & tool_call : data["tool_calls"]) {
|
||||||
|
result.tool_calls.push_back({
|
||||||
|
tool_call["name"],
|
||||||
|
tool_call["arguments"].dump(),
|
||||||
|
tool_call.contains("id") ? tool_call["id"] : "",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (data.contains("tool_call")) {
|
||||||
|
result.tool_calls.push_back({
|
||||||
|
data["tool_call"]["name"],
|
||||||
|
data["tool_call"]["arguments"].dump(),
|
||||||
|
/* id= */ "",
|
||||||
|
});
|
||||||
|
} else if (data.contains("response")) {
|
||||||
|
const auto & response = data["response"];
|
||||||
|
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
common_chat_data data;
|
||||||
|
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
auto schemas = json::array();
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
schemas.push_back({
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
// Important note: the model is probably trained to take a JSON stringified arguments value.
|
||||||
|
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", function["name"]},
|
||||||
|
}},
|
||||||
|
{"arguments", function["parameters"]},
|
||||||
|
{"id", {
|
||||||
|
{"type", "string"},
|
||||||
|
// Nemo's template expects a 9-character alphanumeric ID.
|
||||||
|
{"pattern", "^[a-zA-Z0-9]{9}$"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments", "id"})},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
auto schema = json {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
|
{"minItems", 1},
|
||||||
|
};
|
||||||
|
if (!params.parallel_tool_calls) {
|
||||||
|
schema["maxItems"] = 1;
|
||||||
|
}
|
||||||
|
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||||
|
}, grammar_options);
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||||
|
}
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
||||||
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
||||||
|
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
builtin_tools.push_back("code_interpreter");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_data data;
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
auto has_python = false;
|
||||||
|
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
has_python = true;
|
||||||
|
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
|
||||||
|
has_python = true;
|
||||||
|
} else {
|
||||||
|
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
|
||||||
|
tool_rules.push_back(
|
||||||
|
builder.add_rule(
|
||||||
|
name + "-call",
|
||||||
|
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||||
|
builder.add_schema(name + "-args", parameters) +
|
||||||
|
" \"}\""));
|
||||||
|
if (params.tool_choice != "required" && !eagerly_match_any_json) {
|
||||||
|
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||||
|
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
|
||||||
|
// Note that c++11's regex doesn't support partial matches, otherwise it would make
|
||||||
|
// sense to add support for trigger regexes to the antiprompt mechanism.
|
||||||
|
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||||
|
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||||
|
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||||
|
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_python) {
|
||||||
|
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.tool_choice != "required" && eagerly_match_any_json) {
|
||||||
|
data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
|
}, grammar_options);
|
||||||
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
|
||||||
|
if (uses_python_tag) {
|
||||||
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(input, match, python_tag_regex)) {
|
||||||
|
return {
|
||||||
|
/* .role = */ "assistant",
|
||||||
|
/* .content = */ match.prefix().str(),
|
||||||
|
/* .tool_calls = */ {
|
||||||
|
{
|
||||||
|
/* .name = */ "python",
|
||||||
|
/* .arguments = */ match[1].str(),
|
||||||
|
/* .id = */ "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||||
|
static std::regex close_regex("\\}");
|
||||||
|
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
common_chat_data data;
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
auto schemas = json::array();
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
schemas.push_back({
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", function["name"]},
|
||||||
|
}},
|
||||||
|
{"arguments", function["parameters"]},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments", "id"})},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
auto schema = json {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
|
{"minItems", 1},
|
||||||
|
};
|
||||||
|
if (!params.parallel_tool_calls) {
|
||||||
|
schema["maxItems"] = 1;
|
||||||
|
}
|
||||||
|
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||||
|
}, grammar_options);
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
||||||
|
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
|
common_chat_data data;
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> first_tool_rules;
|
||||||
|
std::vector<std::string> subsequent_tool_rules;
|
||||||
|
auto has_python = false;
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
has_python = true;
|
||||||
|
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||||
|
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||||
|
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||||
|
// Note: if there's a python rule, it needs to come last.
|
||||||
|
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
|
||||||
|
if (has_python && params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
|
||||||
|
data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
if (params.parallel_tool_calls) {
|
||||||
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||||
|
builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
|
||||||
|
} else {
|
||||||
|
builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : ""));
|
||||||
|
}
|
||||||
|
}, grammar_options);
|
||||||
|
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
|
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||||
|
static std::regex close_regex(R"($|(?=>>>))");
|
||||||
|
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||||
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
|
// TODO: handle tool {type: code_interpreter} as python
|
||||||
|
common_chat_data data;
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
auto has_python = false;
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
has_python = true;
|
||||||
|
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
if (name == "python" || name == "ipython") {
|
||||||
|
has_python = true;
|
||||||
|
} else {
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (has_python) {
|
||||||
|
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||||
|
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}, grammar_options);
|
||||||
|
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||||
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(input, match, python_tag_regex)) {
|
||||||
|
return {
|
||||||
|
/* .role = */ "assistant",
|
||||||
|
/* .content = */ match.prefix().str(),
|
||||||
|
/* .tool_calls = */ {
|
||||||
|
{
|
||||||
|
/* .name = */ "python",
|
||||||
|
/* .arguments = */ match[1].str(),
|
||||||
|
/* .id = */ "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||||
|
static std::regex close_regex(R"(</function>)");
|
||||||
|
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
common_chat_data data;
|
||||||
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
for (const auto & tool : params.tools) {
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", json {
|
||||||
|
{"name", json {{"const", name}}},
|
||||||
|
{"arguments", parameters},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments"})},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||||
|
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
|
if (params.tool_choice != "required") {
|
||||||
|
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||||
|
}
|
||||||
|
}, grammar_options);
|
||||||
|
|
||||||
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
|
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
|
||||||
|
try {
|
||||||
|
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||||
|
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||||
|
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||||
|
|
||||||
|
auto end = input.end();
|
||||||
|
std::sregex_iterator rend;
|
||||||
|
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||||
|
if (rit == rend) {
|
||||||
|
return {"assistant", input, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_msg result;
|
||||||
|
result.role = "assistant";
|
||||||
|
result.content = rit->prefix();
|
||||||
|
|
||||||
|
auto it = rit->suffix().first;
|
||||||
|
while (it != end) {
|
||||||
|
json call;
|
||||||
|
if (!parse_json(it, end, call)) {
|
||||||
|
throw std::runtime_error("Failed to parse json tool call");
|
||||||
|
}
|
||||||
|
result.tool_calls.push_back({
|
||||||
|
call["name"],
|
||||||
|
call["arguments"].dump(),
|
||||||
|
/* id= */ "",
|
||||||
|
});
|
||||||
|
rit = {it, end, middle_pattern};
|
||||||
|
if (rit != rend) {
|
||||||
|
it = rit->suffix().first;
|
||||||
|
} else {
|
||||||
|
rit = {it, end, end_pattern};
|
||||||
|
if (rit == rend) {
|
||||||
|
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
return {"assistant", input, {}};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the chat data (prompt, grammar, output parser) for a request.
// When tools are present, the concrete handler is chosen by sniffing the
// template source for format-specific marker strings.
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
    // No tools at all: plain text chat, no grammar constraints needed.
    if (params.tools.is_null()) {
        common_chat_data data;
        data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
        data.handler = std::make_unique<text_chat_parser>();
        return data;
    }

    const auto & src = tmpl.source();
    const auto has_marker = [&](const char * marker) {
        return src.find(marker) != std::string::npos;
    };

    if (has_marker("<tool_call>")) {
        return build_hermes_2_pro_tool_call_handler(tmpl, params);
    }
    if (has_marker(">>>all")) {
        return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
    }
    if (has_marker("<|start_header_id|>") && has_marker("<function=")) {
        return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
    }
    if (has_marker("<|start_header_id|>ipython<|end_header_id|>")) {
        auto uses_python_tag = has_marker("<|python_tag|>");

        // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
        // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
        // as it seems to be outputting some JSON.
        // TODO: make this conditional on a very small model (e.g. 1B / 3B).
        auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;

        return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
    }
    // if (has_marker("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>")) {
    //     TODO: Command-R-Plus
    // }
    if (has_marker("[TOOL_CALLS]")) {
        return build_mistral_nemo_tool_call_handler(tmpl, params);
    }
    if (has_marker(" functools[")) {
        return build_firefunction_v2_tool_call_handler(tmpl, params);
    }
    // Fallback: generic JSON tool-call format.
    return build_generic_tool_call_handler(tmpl, params);
}
|
43
common/chat-handler.hpp
Normal file
43
common/chat-handler.hpp
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
/*
    Copyright 2024 Google LLC

    Use of this source code is governed by an MIT-style
    license that can be found in the LICENSE file or at
    https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "common.h"
#include <json.hpp>
#include <memory>    // std::unique_ptr (was missing; needed by common_chat_data::handler)
#include <optional>  // std::optional (was missing; needed by common_chat_parser::parse_partial)
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

// Inputs of a chat completion request, as extracted from the API payload.
struct common_chat_params {
    json messages;                    // OpenAI-style messages array
    json tools;                       // OpenAI-style tools array (null when no tools)
    json tool_choice;
    json json_schema;
    bool parallel_tool_calls = false; // default-init: avoid reading indeterminate bools
    bool stream = false;
};

// Interface for parsing model output back into a structured chat message
// (content and/or tool calls).
class common_chat_parser {
public:
    virtual ~common_chat_parser() = default;

    // Parse a possibly-incomplete prefix of the output (streaming);
    // returns std::nullopt when no message can be produced yet.
    virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
    // Parse the complete output.
    virtual common_chat_msg parse_final(const std::string & input) = 0;
};

// Everything needed to run one chat request: the rendered prompt, the
// (optional) grammar constraining generation, and the output parser.
struct common_chat_data {
    std::string prompt;
    std::string grammar;                                 // empty when generation is unconstrained
    std::vector<common_grammar_trigger> grammar_triggers; // words that switch sampling into the grammar
    std::vector<std::string> additional_stops;
    std::unique_ptr<class common_chat_parser> handler;
};

struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
|
|
@ -20,7 +20,7 @@ namespace minja {
|
||||||
class chat_template {
|
class chat_template {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
private:
|
// private:
|
||||||
bool supports_tools_ = true;
|
bool supports_tools_ = true;
|
||||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||||
|
@ -28,6 +28,7 @@ class chat_template {
|
||||||
bool requires_typed_content_ = false;
|
bool requires_typed_content_ = false;
|
||||||
bool supports_system_role_ = true;
|
bool supports_system_role_ = true;
|
||||||
bool supports_parallel_tool_calls_ = false;
|
bool supports_parallel_tool_calls_ = false;
|
||||||
|
bool supports_code_interpreter_ = false;
|
||||||
std::string source_;
|
std::string source_;
|
||||||
std::string bos_token_;
|
std::string bos_token_;
|
||||||
std::string eos_token_;
|
std::string eos_token_;
|
||||||
|
@ -60,28 +61,7 @@ class chat_template {
|
||||||
});
|
});
|
||||||
supports_tools_ = source.find("tools") != std::string::npos;
|
supports_tools_ = source.find("tools") != std::string::npos;
|
||||||
|
|
||||||
auto renders_string_arguments =
|
requires_object_arguments_ =
|
||||||
try_raw_render({
|
|
||||||
{
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", "Hey"}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
{"role", "assistant"},
|
|
||||||
{"tool_calls", json::array({
|
|
||||||
{
|
|
||||||
{"id", "call_1___"},
|
|
||||||
{"type", "function"},
|
|
||||||
{"function", {
|
|
||||||
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
|
|
||||||
{"name", "ipython"},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
})},
|
|
||||||
}
|
|
||||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
|
||||||
if (!renders_string_arguments) {
|
|
||||||
auto renders_object_arguments =
|
|
||||||
try_raw_render({
|
try_raw_render({
|
||||||
{
|
{
|
||||||
{"role", "user"},
|
{"role", "user"},
|
||||||
|
@ -102,9 +82,27 @@ class chat_template {
|
||||||
},
|
},
|
||||||
})},
|
})},
|
||||||
}
|
}
|
||||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
}, {}, false).find("{\"code\": \"print") != std::string::npos
|
||||||
requires_object_arguments_ = renders_object_arguments;
|
&& try_raw_render({
|
||||||
|
{
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", "Hey"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"tool_calls", json::array({
|
||||||
|
{
|
||||||
|
{"id", "call_1___"},
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
|
||||||
|
{"name", "ipython"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})},
|
||||||
}
|
}
|
||||||
|
}, {}, false).find("{\"code\": \"print") == std::string::npos;
|
||||||
|
|
||||||
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
|
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
|
||||||
|
|
||||||
supports_system_role_ = try_raw_render({
|
supports_system_role_ = try_raw_render({
|
||||||
|
@ -114,6 +112,8 @@ class chat_template {
|
||||||
|
|
||||||
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
|
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
|
||||||
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
|
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
|
||||||
|
|
||||||
|
supports_code_interpreter_ = source.find("code_interpreter") != std::string::npos;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & source() const { return source_; }
|
const std::string & source() const { return source_; }
|
||||||
|
@ -130,8 +130,45 @@ class chat_template {
|
||||||
bool adjust_inputs = true) const
|
bool adjust_inputs = true) const
|
||||||
{
|
{
|
||||||
json actual_messages;
|
json actual_messages;
|
||||||
|
json actual_tools;
|
||||||
|
|
||||||
// First, "fix" messages so they have a chance to be rendered correctly by the template
|
auto has_code_interpreter = false;
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
|
||||||
|
has_code_interpreter = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
|
||||||
|
actual_tools = json::array();
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
|
||||||
|
static const auto python_tool = json::parse(R"({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "ipython",
|
||||||
|
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The code to run in the ipython interpreter."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})");
|
||||||
|
actual_tools.push_back(python_tool);
|
||||||
|
} else {
|
||||||
|
actual_tools.push_back(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (!tools.is_null()) {
|
||||||
|
actual_tools = tools;
|
||||||
|
}
|
||||||
|
|
||||||
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
|
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
|
||||||
actual_messages = json::array();
|
actual_messages = json::array();
|
||||||
|
@ -173,7 +210,12 @@ class chat_template {
|
||||||
if (tool_call["type"] == "function") {
|
if (tool_call["type"] == "function") {
|
||||||
auto & function = tool_call.at("function");
|
auto & function = tool_call.at("function");
|
||||||
std::string arguments = function.at("arguments");
|
std::string arguments = function.at("arguments");
|
||||||
|
try {
|
||||||
function["arguments"] = json::parse(arguments);
|
function["arguments"] = json::parse(arguments);
|
||||||
|
} catch (const std::exception & ecvt) {
|
||||||
|
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||||
|
function["arguments"] = arguments;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -242,6 +284,9 @@ class chat_template {
|
||||||
} else {
|
} else {
|
||||||
actual_messages = messages;
|
actual_messages = messages;
|
||||||
}
|
}
|
||||||
|
// if (adjust_inputs) {
|
||||||
|
// fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
|
||||||
|
// }
|
||||||
|
|
||||||
auto context = minja::Context::make(json({
|
auto context = minja::Context::make(json({
|
||||||
{"messages", actual_messages},
|
{"messages", actual_messages},
|
||||||
|
@ -251,8 +296,12 @@ class chat_template {
|
||||||
}));
|
}));
|
||||||
|
|
||||||
if (!tools.is_null()) {
|
if (!tools.is_null()) {
|
||||||
auto tools_val = minja::Value(tools);
|
auto tools_val = minja::Value(actual_tools);
|
||||||
context->set("tools", tools_val);
|
context->set("tools", tools_val);
|
||||||
|
if (has_code_interpreter) {
|
||||||
|
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
|
||||||
|
context->set("builtin_tools", builtin_tools_val);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!extra_context.is_null()) {
|
if (!extra_context.is_null()) {
|
||||||
for (auto & kv : extra_context.items()) {
|
for (auto & kv : extra_context.items()) {
|
||||||
|
|
|
@ -109,6 +109,11 @@ enum common_conversation_mode {
|
||||||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// A literal string which, once emitted by the model, enables (triggers)
// grammar-constrained sampling for the rest of the generation.
struct common_grammar_trigger {
    std::string word; // the trigger text to match in the generated stream
    bool at_start;    // if true, only match when the word appears at the very start of the output
};
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
struct common_params_sampling {
|
struct common_params_sampling {
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||||
|
@ -155,7 +160,7 @@ struct common_params_sampling {
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
|
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar
|
||||||
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
|
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
|
||||||
|
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
|
|
@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
std::vector<const char *> trigger_words;
|
std::vector<const char *> trigger_words;
|
||||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||||
for (const auto & str : params.grammar_trigger_words) {
|
for (const auto & str : params.grammar_trigger_words) {
|
||||||
trigger_words.push_back(str.c_str());
|
trigger_words.push_back(str.word.c_str());
|
||||||
}
|
}
|
||||||
auto * result = new common_sampler {
|
auto * result = new common_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
|
|
|
@ -1,747 +0,0 @@
|
||||||
#include "tool-call.h"
|
|
||||||
#include "json-schema-to-grammar.h"
|
|
||||||
#include <algorithm>
|
|
||||||
#include <fstream>
|
|
||||||
#include <map>
|
|
||||||
#include <regex>
|
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <unordered_set>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
|
||||||
|
|
||||||
// Normalize an OpenAI-style tools array: "code_interpreter" tools are
// replaced by a canned "python" function tool, "function" tools pass
// through unchanged, anything else is dropped.
static json normalize_tools(const json & tools) {
    static const auto python_tool = json::parse(R"({
        "type": "function",
        "function": {
            "name": "python",
            "description": "Runs code in an Python interpreter and returns the result of the execution after 60 seconds.",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {
                        "type": "string",
                        "description": "The code to run in the Python interpreter."
                    }
                },
                "required": ["code"]
            }
        }
    })");

    json normalized = json::array();
    for (const auto & tool : tools) {
        if (!tool.contains("type")) {
            continue;
        }
        const auto & type = tool["type"];
        if (type == "code_interpreter") {
            normalized.push_back(python_tool);
        } else if (type == "function") {
            normalized.push_back(tool);
        }
        // any other tool type is silently skipped
    }
    return normalized;
}
|
|
||||||
|
|
||||||
// Human-readable name of a tool-call style, for logging/debugging.
std::string common_tool_call_style_name(common_tool_call_style style) {
    static const std::map<common_tool_call_style, std::string> names {
        {COMMON_TOOL_CALL_STYLE_NONE,                       "None"},
        {COMMON_TOOL_CALL_STYLE_GENERIC,                    "Generic"},
        {COMMON_TOOL_CALL_STYLE_LLAMA_3_1,                  "Llama-3.1"},
        {COMMON_TOOL_CALL_STYLE_LLAMA_3_2,                  "Llama-3.2"},
        {COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,     "FunctionaryV3Llama3"},
        {COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,   "FunctionaryV3Llama3.1"},
        {COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,               "Hermes2Pro"},
        {COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,             "CommandRPlus"},
        {COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,               "MistralNemo"},
        {COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,           "FirefunctionV2"},
    };
    const auto it = names.find(style);
    return it == names.end() ? "Unknown" : it->second;
}
|
|
||||||
|
|
||||||
// Heuristically detect the tool-call style from marker strings in the
// chat template source. Checks are ordered from most to least specific.
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) {
    const auto & src = chat_template.source();
    const auto has = [&](const char * marker) {
        return src.find(marker) != std::string::npos;
    };

    if (has("<tool_call>")) {
        return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
    }
    if (has(">>>all")) {
        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
    }
    if (has("<|start_header_id|>") && has("<function=")) {
        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
    }
    if (has("<|start_header_id|>ipython<|end_header_id|>")) {
        // Llama 3.x: the python tag distinguishes 3.1 from 3.2 templates.
        return has("<|python_tag|>") ? COMMON_TOOL_CALL_STYLE_LLAMA_3_1
                                     : COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
    }
    if (has("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>")) {
        return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
    }
    if (has("[TOOL_CALLS]")) {
        return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
    }
    if (has(" functools[")) {
        return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
    }
    return COMMON_TOOL_CALL_STYLE_GENERIC;
}
|
|
||||||
|
|
||||||
// Parse a single JSON value starting at `it`; on success, store it in `out`,
// advance `it` past the parsed span and return true. A SAX pre-pass is used
// to locate where valid JSON ends, so trailing non-JSON text (e.g. a closing
// tag after the value) does not make the parse fail.
// https://json.nlohmann.me/features/parsing/sax_interface/
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
    // SAX handler that records where the first parse error occurs;
    // every value callback accepts everything, since we only care
    // about *where* parsing stops, not what was parsed.
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;  // offset (from `it`) of the first offending character
        bool found_error;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
            this->position = position - 1; // step back to the last valid character
            this->found_error = true;
            return false; // stop the SAX parse
        }
        bool null() override { return true; }
        bool boolean(bool) override { return true; }
        bool number_integer(number_integer_t) override { return true; }
        bool number_unsigned(number_unsigned_t) override { return true; }
        bool number_float(number_float_t, const string_t &) override { return true; }
        bool string(string_t &) override { return true; }
        bool binary(binary_t &) override { return true; }
        bool start_object(std::size_t) override { return true; }
        bool key(string_t &) override { return true; }
        bool end_object() override { return true; }
        bool start_array(std::size_t) override { return true; }
        bool end_array() override { return true; }
    };
    json_error_locator err_loc;
    json::sax_parse(it, end, &err_loc);

    // Candidate end of the JSON span: either where the error occurred,
    // or the end of the input if the whole remainder parsed cleanly.
    std::string::const_iterator temptative_end;
    if (err_loc.found_error) {
        temptative_end = it + err_loc.position;
    } else {
        temptative_end = end;
    }
    std::string json_sub {it, temptative_end};
    try {
        out = json::parse(json_sub);
        it = temptative_end; // consume exactly the parsed span
        return true;
    } catch (const std::exception &) {
        return false;
    }
}
|
|
||||||
|
|
||||||
/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 *
 * When check_names is true, a match whose captured name is not a declared
 * tool name is treated as plain content rather than a call.
 * Throws std::runtime_error on malformed arguments or a missing close pattern.
 */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";
    auto end = input.end();
    auto it = input.begin();

    // Collect the set of valid tool names (code_interpreter maps to "python").
    std::unordered_set<std::string> tool_names;
    if (check_names) {
        for (const auto & tool : tools) {
            if (!tool.contains("type")) {
                continue;
            }
            std::string type = tool.at("type");
            if (type == "function") {
                tool_names.insert(tool["function"]["name"]);
            } else if (type == "code_interpreter") {
                tool_names.insert("python");
            }
        }
    }

    while (it != end) {
        std::sregex_iterator rend;
        std::sregex_iterator rit(it, end, function_regex);
        if (rit == rend) {
            // No more function openers: the rest is plain content.
            result.content += std::string(it, end);
            break;
        }
        auto name = rit->str(1);
        if (check_names && tool_names.find(name) == tool_names.end()) {
            // Unknown tool name: keep the text (incl. the match) as content and stop.
            result.content += std::string(it, rit->suffix().first);
            break;
        }

        // Text before the opener is content; jump past the opener.
        result.content += std::string(it, rit->prefix().second);
        it = rit->suffix().first;


        // The opener is followed by a JSON arguments value...
        json arguments;
        if (!parse_json(it, end, arguments)) {
            throw std::runtime_error("Failed to parse json tool call arguments");
        }
        // ...and then the closing pattern.
        if (!std::regex_search(it, end, match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern");
        }
        it = match.suffix().first;
        result.tool_calls.push_back({name, arguments.dump(), /* id= */ ""});
    }
    return result;
}
|
|
||||||
|
|
||||||
// Parse Hermes-2-Pro style output: free text optionally followed by one or
// more `<tool_call>{json}</tool_call>` blocks. On any parse failure the
// whole input is returned as plain assistant content (best-effort).
static common_chat_msg parse_hermes_tool_calls(const std::string& input) {
    try {
        std::regex start_pattern(R"([\n\s]*<tool_call>)");
        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");

        auto end = input.end();
        std::sregex_iterator rend;
        std::sregex_iterator rit(input.begin(), end, start_pattern);
        if (rit == rend) {
            // No <tool_call> tag at all: the whole input is content.
            return {"assistant", input, {}};
        }

        common_chat_msg result;
        result.role = "assistant";
        result.content = rit->prefix(); // text before the first <tool_call>

        auto it = rit->suffix().first;
        while (it != end) {
            // Each block contains a JSON object {"name": ..., "arguments": ...}.
            json call;
            if (!parse_json(it, end, call)) {
                throw std::runtime_error("Failed to parse json tool call");
            }
            result.tool_calls.push_back({
                call["name"],
                call["arguments"].dump(),
                /* id= */ "",
            });
            // Either another </tool_call><tool_call> pair follows...
            rit = {it, end, middle_pattern};
            if (rit != rend) {
                it = rit->suffix().first;
            } else {
                // ...or this must be the final, closing </tool_call>.
                rit = {it, end, end_pattern};
                if (rit == rend) {
                    throw std::runtime_error("Malformed input, missing </tool_call>");
                }
                break;
            }
        }
        return result;
    } catch (const std::exception & e) {
        // Best-effort fallback: surface the raw text instead of failing the request.
        return {"assistant", input, {}};
    }
}
|
|
||||||
|
|
||||||
// Parse Llama-3.x style tool calls. When allow_python_tag is set, a trailing
// `<|python_tag|>code` section is converted into a "python" call with the raw
// code as its `code` argument; otherwise calls are inline JSON objects of the
// form {"name": ..., "parameters": {...}} parsed via parse_json_tool_calls.
static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
    if (allow_python_tag) {
        static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
        std::smatch match;
        if (std::regex_search(input, match, python_tag_regex)) {
            return {
                /* .role = */ "assistant",
                /* .content = */ match.prefix().str(),
                /* .tool_calls = */ {
                    {
                        /* .name = */ "python",
                        /* .arguments = */ (json {{"code", match[1].str()}}).dump(),
                        /* .id = */ "",
                    },
                }
            };
        }
    }
    // Matches both `{"type": "function", "name": ...` and bare `{"name": ...` openers.
    static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
    static std::regex close_regex("\\}");
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
|
|
||||||
|
|
||||||
// Parse Functionary v3 (Llama-3.1 flavour) output: either a trailing
// `<|python_tag|>code` section (kept from the Llama 3.1 format) or
// `<function=NAME>{json args}</function>` blocks.
static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
    std::smatch match;
    if (std::regex_search(input, match, python_tag_regex)) {
        return {
            /* .role = */ "assistant",
            /* .content = */ match.prefix().str(),
            /* .tool_calls = */ {
                {
                    /* .name = */ "python",
                    /* .arguments = */ (json {{"code", match[1].str()}}).dump(),
                    /* .id = */ "",
                },
            }
        };
    }
    static std::regex function_regex(R"(<function=(\w+)>)");
    static std::regex close_regex(R"(</function>)");
    // check_names=false: accept any captured function name, not just declared tools.
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
}
|
|
||||||
|
|
||||||
// Parse Functionary v3 (>>> flavour) output: calls look like
// `>>>NAME\n{json args}` separated by further `>>>` markers.
static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
    static std::regex close_regex(R"($|(?=>>>))"); // a call ends at end-of-input or before the next >>>
    return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
|
|
||||||
|
|
||||||
// Parse the generic JSON tool-call format: the whole output is a JSON object
// holding either "tool_calls" (array), "tool_call" (single), or "response"
// (plain content). Throws (via json::parse) on non-JSON input.
static common_chat_msg parse_generic_tool_calls(const std::string& input) {
    auto parsed = json::parse(input);
    common_chat_msg msg;
    msg.role = "assistant";

    if (parsed.contains("tool_calls")) {
        // Parallel form: {"tool_calls": [{"name": ..., "arguments": ..., "id"?: ...}, ...]}
        for (const auto & call : parsed["tool_calls"]) {
            msg.tool_calls.push_back({
                call["name"],
                call["arguments"].dump(),
                call.contains("id") ? call["id"] : "",
            });
        }
        return msg;
    }
    if (parsed.contains("tool_call")) {
        // Single-call form: {"tool_call": {"name": ..., "arguments": ...}}
        msg.tool_calls.push_back({
            parsed["tool_call"]["name"],
            parsed["tool_call"]["arguments"].dump(),
            /* id= */ "",
        });
        return msg;
    }
    if (parsed.contains("response")) {
        // Plain content form: string responses pass through, anything else is pretty-printed.
        const auto & response = parsed["response"];
        msg.content = response.is_string() ? response.get<std::string>() : response.dump(2);
    }
    return msg;
}
|
|
||||||
|
|
||||||
// Parse output of the form `content PREFIX[{call}, ...]`: everything before
// the first occurrence of `prefix` is content, everything after it is a JSON
// array of tool calls. `rstrip_prefix` characters are kept from the end of
// the prefix (e.g. to retain the `[` that opens the JSON array).
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
    auto content_end = input.find(prefix);
    size_t tc_start = std::string::npos;

    common_chat_msg result;
    result.role = "assistant";
    const auto process_tool_calls = [&](const json & tool_calls) {
        for (const auto & tool_call : tool_calls) {
            const auto & arguments = tool_call["arguments"];
            result.tool_calls.push_back({
                tool_call["name"],
                // arguments may already be a stringified object or a JSON object
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    };
    if (content_end == std::string::npos) {
        // No prefix found: the whole input is plain content.
        result.content = input;
    } else {
        tc_start = content_end + prefix.size() - rstrip_prefix;
        result.content = input.substr(0, content_end);
        auto tool_calls = json::parse(input.substr(tc_start));
        process_tool_calls(tool_calls);
    }
    return result;
}
|
|
||||||
|
|
||||||
// Mistral Nemo emits tool calls as `[TOOL_CALLS][{...}, ...]`.
static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) {
    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
|
|
||||||
|
|
||||||
// FireFunction v2 emits ` functools[{...}, ...]`; rstrip_prefix=1 keeps the
// trailing `[` of the prefix so the remainder parses as a JSON array.
static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) {
    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
|
|
||||||
|
|
||||||
// Dispatch model output parsing to the style-specific parser.
// Throws std::runtime_error for styles with no parser (e.g. Command-R-Plus).
common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
    // NOTE(review): unconditional debug logging on every parse — presumably WIP; consider removing or gating.
    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
    switch (style) {
        case COMMON_TOOL_CALL_STYLE_NONE:
            return {"assistant", input, {}};
        case COMMON_TOOL_CALL_STYLE_GENERIC:
            return parse_generic_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
            // third argument is `allow_python_tag` (previous inline comment misnamed it)
            return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ true);
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
            return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ false);
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
            return parse_functionary_v3_tool_calls(tools, input);
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
            return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
            return parse_hermes_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
            return parse_mistral_nemo_tool_calls(input);
        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
            return parse_firefunction_v2_tool_calls(input);
        default:
            throw std::runtime_error("Unsupported tool call style");
    }
}
|
|
||||||
|
|
||||||
// Return a copy of `messages` with `system_prompt` applied: appended (after a
// newline) to an existing leading system message, or inserted as a new system
// message at the front otherwise.
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
    json result = messages;

    if (!result.empty() && result[0].at("role") == "system") {
        // Merge into the existing system message.
        std::string previous = result.at(0).at("content");
        result[0] = json {
            {"role", "system"},
            {"content", previous + "\n" + system_prompt},
        };
    } else {
        // No leading system message: insert a fresh one at the front.
        result.insert(result.begin(), json {
            {"role", "system"},
            {"content", system_prompt},
        });
    }
    return result;
}
|
|
||||||
|
|
||||||
// Builds everything needed to serve a tool-call-enabled chat request for the
// given template style: the rendered prompt, an optional GBNF grammar that
// constrains tool-call output, trigger words that (lazily) activate that
// grammar, and any extra stop strings the template requires.
//
// - style:               tool-call dialect expected by the chat template.
// - tmpl:                chat template used to render the prompt.
// - allow_content:       when true, the model may also answer with plain
//                        content; the grammar is then only activated once one
//                        of the trigger words is generated.
// - parallel_tool_calls: tri-state; null means "use the template's default".
// - messages/tools:      OpenAI-style request payload.
// - json_schema:         optional schema constraining plain-content responses
//                        (only used by the GENERIC style).
//
// Throws std::runtime_error for styles without a handler implementation.
common_tool_call_handler common_tool_call_handler_init(
    common_tool_call_style style,
    const common_chat_template & tmpl,
    bool allow_content,
    const nlohmann::ordered_json & parallel_tool_calls,
    const nlohmann::ordered_json & messages,
    const nlohmann::ordered_json & tools,
    const nlohmann::ordered_json & json_schema)
{
    // Local options intentionally shadow the file-level `grammar_options`.
    common_grammar_options grammar_options {
        /* .dotall = */ false,
        /* .compact_spaces = */ true,
    };
    common_tool_call_handler handler;
    auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();

    switch (style) {
        case COMMON_TOOL_CALL_STYLE_NONE:
            // No tool-call constraint: just render the prompt.
            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
            break;
        case COMMON_TOOL_CALL_STYLE_GENERIC: {
            // Generic style: force the whole response to be a JSON object that is
            // either a tool call (or list thereof) or a {"response": ...} wrapper.
            auto actual_tools = normalize_tools(tools);
            auto tool_call_schemas = json::array();
            for (const auto & tool : actual_tools) {
                const auto & function = tool["function"];
                auto tool_schema = json {
                    {"type", "object"},
                    {"properties", {
                        {"name", {
                            {"type", "string"},
                            {"const", function["name"]},
                        }},
                        {"arguments", function["parameters"]},
                    }},
                    {"required", json::array({"name", "arguments"})},
                };
                if (function.contains("description")) {
                    tool_schema["description"] = function["description"];
                }
                if (parallel) {
                    // With parallel calls the model must emit an id so calls can
                    // be matched to their results.
                    tool_schema["properties"]["id"] = {
                        {"type", "string"},
                        {"minLength", 4},
                    };
                    tool_schema["required"].push_back("id");
                }
                tool_call_schemas.emplace_back(tool_schema);
            }
            const auto tool_call =
                parallel
                    ? json {
                        {"type", "object"},
                        {"properties", {
                            {"tool_calls", {
                                {"type", "array"},
                                {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                                    {"anyOf", tool_call_schemas},
                                }},
                                {"minItems", 1},
                            }},
                        }},
                        {"required", json::array({"tool_calls"})},
                    }
                    : json {
                        {"type", "object"},
                        {"properties", {
                            {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                                {"anyOf", tool_call_schemas},
                            }},
                        }},
                        {"required", json::array({"tool_call"})},
                    };
            const auto schema =
                allow_content
                    ? json {
                        {"anyOf", json::array({
                            tool_call,
                            {
                                {"type", "object"},
                                {"properties", {
                                    {"response", json_schema.is_null()
                                        ? json {{"type", "string"}}
                                        : json_schema
                                    },
                                }},
                                {"required", json::array({"response"})},
                            },
                        })}
                    }
                    : tool_call;
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                builder.add_schema("root", schema);
            }, grammar_options);
            // Teach the model the expected response format by injecting the
            // schema into the system prompt.
            auto tweaked_messages = add_system(
                messages,
                "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
            handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
            // Nemo emits "[TOOL_CALLS]" followed by a JSON array of calls.
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                auto schemas = json::array();
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    schemas.push_back({
                        {"type", "object"},
                        {"properties", {
                            // Important note: the model is probably trained to take a JSON stringified arguments value.
                            // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
                            {"name", {
                                {"type", "string"},
                                {"const", function["name"]},
                            }},
                            {"arguments", function["parameters"]},
                            {"id", {
                                {"type", "string"},
                                // Nemo's template expects a 9-character alphanumeric ID.
                                {"pattern", "^[a-zA-Z0-9]{9}$"},
                            }},
                        }},
                        {"required", json::array({"name", "arguments", "id"})},
                    });
                }
                auto schema = json {
                    {"type", "array"},
                    {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                    {"minItems", 1},
                };
                if (!parallel) {
                    schema["maxItems"] = 1;
                }
                builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
            }, grammar_options);
            if (allow_content) {
                handler.grammar_triggers.push_back("[TOOL_CALLS]");
            }
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
            // FireFunction v2 emits " functools" followed by a JSON array of calls.
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                auto schemas = json::array();
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    schemas.push_back({
                        {"type", "object"},
                        {"properties", {
                            {"name", {
                                {"type", "string"},
                                {"const", function["name"]},
                            }},
                            {"arguments", function["parameters"]},
                        }},
                        // NOTE(review): "id" is required here but has no entry in
                        // "properties" (unlike the Nemo case above) — confirm
                        // whether an id property/pattern was intended.
                        {"required", json::array({"name", "arguments", "id"})},
                    });
                }
                auto schema = json {
                    {"type", "array"},
                    {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                    {"minItems", 1},
                };
                if (!parallel) {
                    schema["maxItems"] = 1;
                }
                builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
            }, grammar_options);
            if (allow_content) {
                handler.grammar_triggers.push_back(" functools[");
            }
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
            auto builtin_tools = json {"wolfram_alpha", "brave_search"};
            for (const auto & tool : tools) {
                if (!tool.contains("type")) {
                    continue;
                }
                if (tool["type"] == "code_interpreter") {
                    builtin_tools.push_back("code_interpreter");
                    break;
                }
            }
            auto actual_tools = normalize_tools(tools);

            auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;

            // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
            // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
            // as it seems to be outputting some JSON.
            // TODO: make this conditional on a very small model (e.g. 1B / 3B).
            auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;

            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;

                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    builder.resolve_refs(parameters);
                    // BUG FIX: nlohmann's json::contains() looks up object keys
                    // and always returns false on arrays, so the original
                    // `builtin_tools.contains(name)` never matched; scan the
                    // array explicitly instead.
                    bool is_builtin = false;
                    for (const auto & builtin : builtin_tools) {
                        if (builtin == name) {
                            is_builtin = true;
                            break;
                        }
                    }
                    if (uses_python_tag && (name == "ipython" || is_builtin)) {
                        tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
                        if (allow_content) {
                            handler.grammar_triggers.push_back("<|python_tag|>");
                        }
                    } else {
                        //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
                        tool_rules.push_back(
                            builder.add_rule(
                                name + "-call",
                                "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                                    builder.add_schema(name + "-args", parameters) +
                                " \"}\""));
                        if (allow_content && !eagerly_match_any_json) {
                            handler.grammar_triggers.push_back("{\"name\": \"" + name + "\"");
                            // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
                            // Note that c++11's regex doesn't support partial matches, otherwise it would make
                            // sense to add support for trigger regexes to the antiprompt mechanism.
                            handler.grammar_triggers.push_back("{\n\t\"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\n  \"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\n    \"name\": \"" + name + "\"");
                            handler.grammar_triggers.push_back("{\"type\": \"function\", \"name\": \"" + name + "\"");
                        }
                    }
                }

                if (allow_content && eagerly_match_any_json) {
                    handler.grammar_triggers.push_back("{\"");
                    handler.grammar_triggers.push_back("{\n\t\"");
                    handler.grammar_triggers.push_back("{\n  \"");
                    handler.grammar_triggers.push_back("{\n    \"");
                }

                builder.add_rule("root", string_join(tool_rules, " | "));
            }, grammar_options);
            handler.additional_stops.push_back("<|eom_id|>");
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
                {"builtin_tools", builtin_tools},
            });
            break;
        }
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
            // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
            // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> first_tool_rules;
                std::vector<std::string> subsequent_tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    auto args_rule = builder.add_schema(name + "-args", parameters);
                    first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                    subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
                    if (allow_content) {
                        handler.grammar_triggers.push_back(name + "\n");
                        handler.grammar_triggers.push_back("\n>>>" + name + "\n");
                    }
                }
                auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
                if (parallel) {
                    auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
                    builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
                } else {
                    builder.add_rule("root", first_rule);
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            // handler.parser = parse_functionary_3_2_tool_calls;
            break;
        }
        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
            // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
            // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
            // TODO: handle tool {type: code_interpreter} as python
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    if (name == "python" || name == "ipython") {
                        tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
                        if (allow_content) {
                            handler.grammar_triggers.push_back("<|python_tag|>");
                        }
                    } else {
                        tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
                    }
                }
                auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
                if (allow_content) {
                    handler.grammar_triggers.push_back("<function=");
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            // handler.parser = parse_functionary_3_2_tool_calls;
            break;
        }
        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
            // NousResearchHermesPro_2
            // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
            auto actual_tools = normalize_tools(tools);
            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                std::vector<std::string> tool_rules;
                for (const auto & tool : actual_tools) {
                    const auto & function = tool["function"];
                    std::string name = function["name"];
                    auto parameters = function["parameters"];
                    builder.resolve_refs(parameters);
                    tool_rules.push_back(builder.add_schema(name + "-call", {
                        {"type", "object"},
                        {"properties", json {
                            {"name", json {{"const", name}}},
                            {"arguments", parameters},
                        }},
                        {"required", json::array({"name", "arguments"})},
                    }));
                }

                auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
                if (allow_content) {
                    handler.grammar_triggers.push_back("<tool_call>");
                }
            }, grammar_options);
            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
            break;
        }
        default:
            throw std::runtime_error("Unsupported tool call style");
    }
    return handler;
}
|
|
|
@ -1,44 +0,0 @@
|
||||||
#pragma once

#include "ggml.h"
#include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

// Tool-call dialect a chat template expects. Detected heuristically from the
// template (see common_tool_call_style_detect) and used to pick the matching
// grammar builder and output parser.
enum common_tool_call_style {
    COMMON_TOOL_CALL_STYLE_UNKNOWN,
    COMMON_TOOL_CALL_STYLE_NONE,
    COMMON_TOOL_CALL_STYLE_GENERIC,
    COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
    COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
    COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
    COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
    COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
    COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
};

// Result of common_tool_call_handler_init: everything a server needs to run a
// tool-call-enabled completion for a given template style.
struct common_tool_call_handler {
    // Fully rendered chat prompt (template already applied).
    std::string prompt;
    // GBNF grammar constraining tool-call output; empty when unconstrained.
    std::string grammar;
    // Literal strings whose generation activates the grammar (used when plain
    // content is also allowed).
    std::vector<std::string> grammar_triggers;
    // Extra stop strings required by the template (e.g. "<|eom_id|>" for Llama 3).
    std::vector<std::string> additional_stops;
};

// Human-readable name of a style, for logging.
std::string common_tool_call_style_name(common_tool_call_style style);

// Heuristically detects the tool-call style from a chat template.
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);

// Parses raw model output into a chat message (content and/or tool calls)
// according to the given style and tool definitions.
common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

// Builds prompt + grammar (+ triggers/stops) for a tool-call request.
// `parallel_tool_calls` is tri-state: null means "use the template's default".
// `json_schema` optionally constrains plain-content responses (GENERIC style).
common_tool_call_handler common_tool_call_handler_init(
    common_tool_call_style style,
    const common_chat_template & tmpl,
    bool allow_content,
    const nlohmann::ordered_json & parallel_tool_calls,
    const nlohmann::ordered_json & messages,
    const nlohmann::ordered_json & tools,
    const nlohmann::ordered_json & json_schema = {});
|
|
|
@ -117,8 +117,7 @@ struct slot_params {
|
||||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_cmpl_id;
|
std::string oaicompat_cmpl_id;
|
||||||
json oaicompat_tools;
|
std::shared_ptr<common_chat_parser> chat_parser;
|
||||||
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
|
|
||||||
|
|
||||||
json to_json() const {
|
json to_json() const {
|
||||||
std::vector<std::string> samplers;
|
std::vector<std::string> samplers;
|
||||||
|
@ -166,7 +165,7 @@ struct slot_params {
|
||||||
{"n_probs", sampling.n_probs},
|
{"n_probs", sampling.n_probs},
|
||||||
{"min_keep", sampling.min_keep},
|
{"min_keep", sampling.min_keep},
|
||||||
{"grammar", sampling.grammar},
|
{"grammar", sampling.grammar},
|
||||||
{"grammar_trigger_words", sampling.grammar_trigger_words},
|
// {"grammar_trigger_words", sampling.grammar_trigger_words},
|
||||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||||
{"samplers", samplers},
|
{"samplers", samplers},
|
||||||
{"speculative.n_max", speculative.n_max},
|
{"speculative.n_max", speculative.n_max},
|
||||||
|
@ -212,6 +211,7 @@ struct server_task {
|
||||||
static slot_params params_from_json_cmpl(
|
static slot_params params_from_json_cmpl(
|
||||||
const llama_context * ctx,
|
const llama_context * ctx,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
|
const common_chat_template * tmpl,
|
||||||
const json & data) {
|
const json & data) {
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
@ -322,6 +322,41 @@ struct server_task {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
params.antiprompt.clear();
|
||||||
|
const auto stop = data.find("stop");
|
||||||
|
if (stop != data.end()) {
|
||||||
|
params.antiprompt = *stop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tmpl && params_base.use_jinja) {
|
||||||
|
common_chat_params chat_params;
|
||||||
|
chat_params.messages = json_value(data, "messages", json::array());
|
||||||
|
chat_params.tools = json_value(data, "tools", json());
|
||||||
|
chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
|
||||||
|
chat_params.json_schema = json_value(data, "json_schema", json());
|
||||||
|
chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
|
||||||
|
chat_params.stream = json_value(data, "stream", false);
|
||||||
|
|
||||||
|
auto chat_data = common_chat_init(*tmpl, chat_params);
|
||||||
|
params.chat_parser = std::move(chat_data.handler);
|
||||||
|
params.sampling.grammar = chat_data.grammar;
|
||||||
|
for (const auto & stop : chat_data.additional_stops) {
|
||||||
|
params.antiprompt.push_back(stop);
|
||||||
|
}
|
||||||
|
for (const auto & trigger : chat_data.grammar_triggers) {
|
||||||
|
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||||
|
if (ids.size() == 1) {
|
||||||
|
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
|
||||||
|
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
|
||||||
|
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// process "json_schema" and "grammar"
|
// process "json_schema" and "grammar"
|
||||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||||
|
@ -336,13 +371,7 @@ struct server_task {
|
||||||
} else {
|
} else {
|
||||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||||
}
|
}
|
||||||
|
LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||||
if (data.contains("tools")) {
|
|
||||||
params.oaicompat_tools = data.at("tools");
|
|
||||||
}
|
|
||||||
if (data.contains("tool_call_style")) {
|
|
||||||
params.oaicompat_tool_call_style = data.at("tool_call_style");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
params.sampling.logit_bias.clear();
|
params.sampling.logit_bias.clear();
|
||||||
|
@ -379,45 +408,11 @@ struct server_task {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto to_string_vec = [](const json & j) {
|
|
||||||
std::vector<std::string> out;
|
|
||||||
if (j.is_array()) {
|
|
||||||
for (const auto & e : j) {
|
|
||||||
if (e.is_string()) {
|
|
||||||
out.push_back(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
params.antiprompt.clear();
|
|
||||||
const auto stop = data.find("stop");
|
|
||||||
if (stop != data.end()) {
|
|
||||||
params.antiprompt = to_string_vec(*stop);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto grammar_trigger_words = data.find("grammar_trigger_words");
|
|
||||||
if (grammar_trigger_words != data.end()) {
|
|
||||||
for (const auto & word : to_string_vec(*grammar_trigger_words)) {
|
|
||||||
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
|
|
||||||
if (ids.size() == 1) {
|
|
||||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
params.sampling.grammar_trigger_words.push_back(word);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
const auto samplers = data.find("samplers");
|
const auto samplers = data.find("samplers");
|
||||||
if (samplers != data.end()) {
|
if (samplers != data.end()) {
|
||||||
if (samplers->is_array()) {
|
if (samplers->is_array()) {
|
||||||
params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
|
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
|
||||||
} else if (samplers->is_string()){
|
} else if (samplers->is_string()){
|
||||||
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
||||||
}
|
}
|
||||||
|
@ -592,8 +587,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_cmpl_id;
|
std::string oaicompat_cmpl_id;
|
||||||
json oaicompat_tools;
|
common_chat_msg oaicompat_chat_msg;
|
||||||
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
|
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
|
@ -688,18 +682,13 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
json to_json_oaicompat_chat() {
|
json to_json_oaicompat_chat() {
|
||||||
std::string finish_reason = "length";
|
std::string finish_reason = "length";
|
||||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
finish_reason = "stop";
|
finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||||
}
|
}
|
||||||
|
|
||||||
json tool_calls;
|
json tool_calls;
|
||||||
json message_content;
|
if (!oaicompat_chat_msg.tool_calls.empty()) {
|
||||||
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
|
|
||||||
auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
|
|
||||||
if (!parsed_tool_calls.tool_calls.empty()) {
|
|
||||||
finish_reason = "tool_calls";
|
|
||||||
message_content = parsed_tool_calls.content;
|
|
||||||
tool_calls = json::array();
|
tool_calls = json::array();
|
||||||
for (const auto & tc : parsed_tool_calls.tool_calls) {
|
for (const auto & tc : oaicompat_chat_msg.tool_calls) {
|
||||||
tool_calls.push_back({
|
tool_calls.push_back({
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
|
@ -709,18 +698,13 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
{"id", tc.id.empty() ? json() : json(tc.id)},
|
{"id", tc.id.empty() ? json() : json(tc.id)},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
message_content = parsed_tool_calls.content;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
message_content = content;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
json choice {
|
json choice {
|
||||||
{"finish_reason", finish_reason},
|
{"finish_reason", finish_reason},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"message", json {
|
{"message", json {
|
||||||
{"content", message_content},
|
{"content", oaicompat_chat_msg.content},
|
||||||
{"tool_calls", tool_calls},
|
{"tool_calls", tool_calls},
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
}},
|
}},
|
||||||
|
@ -812,6 +796,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_cmpl_id;
|
std::string oaicompat_cmpl_id;
|
||||||
|
common_chat_msg oaicompat_chat_msg;
|
||||||
|
std::shared_ptr<common_chat_parser> chat_parser;
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
|
@ -1234,6 +1220,8 @@ struct server_slot {
|
||||||
|
|
||||||
std::string stopping_word;
|
std::string stopping_word;
|
||||||
|
|
||||||
|
std::shared_ptr<common_chat_parser> chat_parser;
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
json json_schema;
|
json json_schema;
|
||||||
|
|
||||||
|
@ -2260,6 +2248,10 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
|
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
|
||||||
|
auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
|
||||||
|
if (!opt_msg) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
||||||
|
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
|
@ -2275,6 +2267,7 @@ struct server_context {
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_model = slot.params.oaicompat_model;
|
res->oaicompat_model = slot.params.oaicompat_model;
|
||||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
|
res->oaicompat_chat_msg = *opt_msg;
|
||||||
|
|
||||||
// populate res.probs_output
|
// populate res.probs_output
|
||||||
if (slot.params.sampling.n_probs > 0) {
|
if (slot.params.sampling.n_probs > 0) {
|
||||||
|
@ -2315,8 +2308,7 @@ struct server_context {
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_model = slot.params.oaicompat_model;
|
res->oaicompat_model = slot.params.oaicompat_model;
|
||||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
res->oaicompat_tools = slot.params.oaicompat_tools;
|
res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
|
||||||
res->oaicompat_tool_call_style = slot.params.oaicompat_tool_call_style;
|
|
||||||
|
|
||||||
// populate res.probs_output
|
// populate res.probs_output
|
||||||
if (slot.params.sampling.n_probs > 0) {
|
if (slot.params.sampling.n_probs > 0) {
|
||||||
|
@ -3776,7 +3768,7 @@ int main(int argc, char ** argv) {
|
||||||
std::function<bool()> is_connection_closed,
|
std::function<bool()> is_connection_closed,
|
||||||
httplib::Response & res,
|
httplib::Response & res,
|
||||||
oaicompat_type oaicompat,
|
oaicompat_type oaicompat,
|
||||||
common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) {
|
const common_chat_template * tmpl) {
|
||||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||||
|
|
||||||
if (ctx_server.params_base.embedding) {
|
if (ctx_server.params_base.embedding) {
|
||||||
|
@ -3788,6 +3780,20 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
|
||||||
|
std::string prompt;
|
||||||
|
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||||
|
auto chat_data = common_chat_init(*tmpl, {
|
||||||
|
/* .messages = */ json_data(data, "messages", json::array()),
|
||||||
|
/* .tools = */ json_data(data, "tools", json()),
|
||||||
|
/
|
||||||
|
});
|
||||||
|
|
||||||
|
prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
|
||||||
|
} else {
|
||||||
|
prompt = data.at("prompt").get<std::string>();
|
||||||
|
}
|
||||||
|
task.params.chat_parser = common_chat_init()
|
||||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
||||||
tasks.reserve(tokenized_prompts.size());
|
tasks.reserve(tokenized_prompts.size());
|
||||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
|
@ -3800,12 +3806,14 @@ int main(int argc, char ** argv) {
|
||||||
task.params = server_task::params_from_json_cmpl(
|
task.params = server_task::params_from_json_cmpl(
|
||||||
ctx_server.ctx,
|
ctx_server.ctx,
|
||||||
ctx_server.params_base,
|
ctx_server.params_base,
|
||||||
|
nullptr,
|
||||||
data);
|
data);
|
||||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = oaicompat;
|
task.params.oaicompat = oaicompat;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
|
task.params.chat_parser = common_chat_init()
|
||||||
task.params.oaicompat_tools = json_value(data, "tools", json());
|
task.params.oaicompat_tools = json_value(data, "tools", json());
|
||||||
task.params.oaicompat_tool_call_style = tool_call_style;
|
task.params.oaicompat_tool_call_style = tool_call_style;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by params_from_json_cmpl
|
||||||
|
@ -3983,18 +3991,16 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
auto body = json::parse(req.body);
|
auto body = json::parse(req.body);
|
||||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||||
auto tool_call_style = common_tool_call_style_detect(chat_template);
|
LOG_INF("Request: %s\n", body.dump(2).c_str());
|
||||||
LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str());
|
|
||||||
|
|
||||||
json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja);
|
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||||
|
|
||||||
return handle_completions_impl(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
req.is_connection_closed,
|
req.is_connection_closed,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_CHAT,
|
OAICOMPAT_TYPE_CHAT);
|
||||||
tool_call_style);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "minja.hpp"
|
#include "minja.hpp"
|
||||||
|
#include "chat-handler.hpp"
|
||||||
#include "chat-template.hpp"
|
#include "chat-template.hpp"
|
||||||
#include "tool-call.h"
|
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
@ -581,24 +581,18 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||||
static json oaicompat_completion_params_parse(
|
static json oaicompat_completion_params_parse(
|
||||||
const json & body, /* openai api json semantics */
|
const json & body, /* openai api json semantics */
|
||||||
const common_chat_template & tmpl,
|
const common_chat_template & tmpl,
|
||||||
common_tool_call_style tool_call_style,
|
|
||||||
bool use_jinja)
|
bool use_jinja)
|
||||||
{
|
{
|
||||||
json llama_params;
|
json llama_params;
|
||||||
|
|
||||||
auto tools = json_value(body, "tools", json());
|
auto tools = json_value(body, "tools", json());
|
||||||
auto has_tools = tools.is_array() && !tools.empty();
|
|
||||||
auto stream = json_value(body, "stream", false);
|
auto stream = json_value(body, "stream", false);
|
||||||
|
|
||||||
if (has_tools) {
|
if (tools.is_array() && !tools.empty()) {
|
||||||
if (stream) {
|
if (stream) {
|
||||||
throw std::runtime_error("Cannot use tools with stream");
|
throw std::runtime_error("Cannot use tools with stream");
|
||||||
}
|
}
|
||||||
if (use_jinja) {
|
if (!use_jinja) {
|
||||||
if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
|
|
||||||
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("tools param requires --jinja flag");
|
throw std::runtime_error("tools param requires --jinja flag");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -627,32 +621,16 @@ static json oaicompat_completion_params_parse(
|
||||||
|
|
||||||
// Apply chat template to the list of messages
|
// Apply chat template to the list of messages
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
bool allow_content = tool_choice != "required";
|
|
||||||
if (tool_choice != "none" && has_tools) {
|
|
||||||
llama_params["tools"] = tools;
|
llama_params["tools"] = tools;
|
||||||
llama_params["tool_call_style"] = tool_call_style;
|
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
|
||||||
|
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
||||||
auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json();
|
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
||||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
|
||||||
|
|
||||||
auto handler = common_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
|
|
||||||
llama_params["prompt"] = handler.prompt;
|
|
||||||
|
|
||||||
for (const auto & stop : handler.additional_stops) {
|
|
||||||
llama_params["stop"].push_back(stop);
|
|
||||||
}
|
}
|
||||||
if (!handler.grammar_triggers.empty()) {
|
llama_params["tool_choice"] = tool_choice;
|
||||||
llama_params["grammar_trigger_words"] = handler.grammar_triggers;
|
llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
|
||||||
}
|
if (tool_choice != "none" && llama_params.contains("grammar")) {
|
||||||
if (!handler.grammar.empty()) {
|
|
||||||
if (llama_params.contains("grammar")) {
|
|
||||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||||
}
|
}
|
||||||
llama_params["grammar"] = handler.grammar;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ llama_target_and_test(test-chat-template.cpp)
|
||||||
# llama_target_and_test(test-opt.cpp) # SLOW
|
# llama_target_and_test(test-opt.cpp) # SLOW
|
||||||
llama_target_and_test(test-gguf.cpp)
|
llama_target_and_test(test-gguf.cpp)
|
||||||
llama_target_and_test(test-backend-ops.cpp)
|
llama_target_and_test(test-backend-ops.cpp)
|
||||||
llama_target_and_test(test-tool-call.cpp)
|
llama_target_and_test(test-chat-handler.cpp)
|
||||||
|
|
||||||
llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
|
llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
|
||||||
llama_target_and_test(test-autorelease.cpp LABEL "model")
|
llama_target_and_test(test-autorelease.cpp LABEL "model")
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#include "tool-call.h"
|
#include "chat-handler.hpp"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ static void assert_equals(const T & expected, const T & actual) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string read_file(const std::string &path) {
|
static std::string read_file(const std::string &path) {
|
||||||
|
std::cout << "# Reading: " << path << std::endl << std::flush;
|
||||||
std::ifstream fs(path, std::ios_base::binary);
|
std::ifstream fs(path, std::ios_base::binary);
|
||||||
if (!fs.is_open()) {
|
if (!fs.is_open()) {
|
||||||
fs = std::ifstream("../" + path, std::ios_base::binary);
|
fs = std::ifstream("../" + path, std::ios_base::binary);
|
||||||
|
@ -76,10 +77,16 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
|
||||||
assert_equals(expected_content, result.content);
|
assert_equals(expected_content, result.content);
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
for (const auto & tc : result.tool_calls) {
|
for (const auto & tc : result.tool_calls) {
|
||||||
|
auto arguments = tc.arguments;
|
||||||
|
try {
|
||||||
|
arguments = dump(json::parse(arguments));
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
auto tool_call = json {
|
auto tool_call = json {
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"arguments", dump(json::parse(tc.arguments))},
|
{"arguments", arguments},
|
||||||
{"name", tc.name},
|
{"name", tc.name},
|
||||||
}},
|
}},
|
||||||
};
|
};
|
||||||
|
@ -94,8 +101,7 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
|
||||||
assert_equals(expected, actual);
|
assert_equals(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
const json tools = json::parse(R"([
|
const auto special_function_tool = json::parse(R"({
|
||||||
{
|
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "special_function",
|
"name": "special_function",
|
||||||
|
@ -111,25 +117,28 @@ const json tools = json::parse(R"([
|
||||||
"required": ["arg1"]
|
"required": ["arg1"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
})");
|
||||||
{
|
const auto python_tool = json::parse(R"({
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"description": "a python interpreter",
|
"description": "an ipython interpreter",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"code": {
|
"code": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The code."
|
"description": "Python code to execute."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["code"]
|
"required": ["code"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
})");
|
||||||
])");
|
const auto code_interpreter_tool = json::parse(R"({
|
||||||
|
"type": "code_interpreter"
|
||||||
|
})");
|
||||||
|
const json tools = {special_function_tool, code_interpreter_tool};
|
||||||
|
|
||||||
static void test_parsing() {
|
static void test_parsing() {
|
||||||
json request = {
|
json request = {
|
||||||
|
@ -226,9 +235,7 @@ static void test_parsing() {
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "python"},
|
{"name", "python"},
|
||||||
{"arguments", dump({
|
{"arguments", "this could be anything"},
|
||||||
{"code", "this could be anything"}
|
|
||||||
})}
|
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
|
@ -238,7 +245,7 @@ static void test_parsing() {
|
||||||
{"type", "function"},
|
{"type", "function"},
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "python"},
|
{"name", "python"},
|
||||||
{"arguments", dump({{"code", ""}})}
|
{"arguments", ""},
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
auto special_function_call = json {
|
auto special_function_call = json {
|
||||||
|
@ -332,6 +339,8 @@ static void test_tool_call_style_detection() {
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
|
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
||||||
|
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
||||||
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
|
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
|
||||||
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
|
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
|
||||||
|
|
||||||
|
@ -354,9 +363,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
||||||
return delta;
|
return delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
||||||
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
|
|
||||||
const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
|
|
||||||
auto tool_call_style = common_tool_call_style_detect(tmpl);
|
auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||||
auto & tool_calls = tool_calling_message.at("tool_calls");
|
auto & tool_calls = tool_calling_message.at("tool_calls");
|
||||||
|
|
||||||
|
@ -404,24 +411,77 @@ static void test_grammars() {
|
||||||
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
|
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
|
||||||
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
|
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
|
||||||
|
|
||||||
test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
|
auto python_tool_call_message = json {
|
||||||
/* skip_grammar_test= */ true);
|
{"role", "assistant"},
|
||||||
test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
{"content", ""},
|
||||||
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
{"tool_calls", json {{
|
||||||
test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
|
{"type", "function"},
|
||||||
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
{"function", {
|
||||||
test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
{"name", "python"},
|
||||||
test_template("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
{"arguments", "print('hey')"}
|
||||||
test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
}},
|
||||||
test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
}}}
|
||||||
test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "<s>", "</s>", { "<|eot_id|>" }, tool_call_message, tools);
|
};
|
||||||
test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "<s>", "</s>", { "<end_of_turn>" }, tool_call_message_with_id, tools);
|
|
||||||
test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "<s>", "</s>", { "<|end|>" }, tool_call_message_with_id, tools);
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
// assert_equals(tmpl.requires_object_arguments_, true);
|
||||||
|
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||||
|
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
// {
|
||||||
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
// }
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
|
test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||||
|
test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
|
||||||
|
test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_tool_call_style_detection();
|
// test_tool_call_style_detection();
|
||||||
test_parsing();
|
// test_parsing();
|
||||||
test_grammars();
|
test_grammars();
|
||||||
|
|
||||||
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
std::cout << "\n[tool-call] All tests passed!" << std::endl;
|
Loading…
Add table
Add a link
Reference in a new issue