WIP chat handlers

This commit is contained in:
Olivier Chafik 2025-01-24 02:31:37 +00:00
parent 46415d7a51
commit 36ed106f84
13 changed files with 1071 additions and 999 deletions

View file

@ -1363,10 +1363,11 @@ llama-server: \
examples/server/httplib.h \ examples/server/httplib.h \
examples/server/index.html.hpp \ examples/server/index.html.hpp \
examples/server/loading.html.hpp \ examples/server/loading.html.hpp \
common/chat-handler.cpp \
common/chat-handler.hpp \
common/chat-template.hpp \ common/chat-template.hpp \
common/json.hpp \ common/json.hpp \
common/minja.hpp \ common/minja.hpp \
common/tool-call.h \
$(OBJ_ALL) $(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

View file

@ -56,6 +56,8 @@ add_library(${TARGET} STATIC
arg.cpp arg.cpp
arg.h arg.h
base64.hpp base64.hpp
chat-handler.cpp
chat-handler.hpp
chat-template.hpp chat-template.hpp
common.cpp common.cpp
common.h common.h
@ -72,7 +74,6 @@ add_library(${TARGET} STATIC
sampling.h sampling.h
speculative.cpp speculative.cpp
speculative.h speculative.h
tool-call.cpp
) )
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)

720
common/chat-handler.cpp Normal file
View file

@ -0,0 +1,720 @@
#include "chat-handler.hpp"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h"
#include "minja.hpp"
const common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ false,
// /* .compact_spaces = */ true,
};
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
this->position = position - 1;
this->found_error = true;
return false;
}
bool null() override { return true; }
bool boolean(bool) override { return true; }
bool number_integer(number_integer_t) override { return true; }
bool number_unsigned(number_unsigned_t) override { return true; }
bool number_float(number_float_t, const string_t &) override { return true; }
bool string(string_t &) override { return true; }
bool binary(binary_t &) override { return true; }
bool start_object(std::size_t) override { return true; }
bool key(string_t &) override { return true; }
bool end_object() override { return true; }
bool start_array(std::size_t) override { return true; }
bool end_array() override { return true; }
};
json_error_locator err_loc;
json::sax_parse(it, end, &err_loc);
std::string::const_iterator temptative_end;
if (err_loc.found_error) {
temptative_end = it + err_loc.position;
} else {
temptative_end = end;
}
std::string json_sub {it, temptative_end};
try {
out = json::parse(json_sub);
it = temptative_end;
return true;
} catch (const std::exception &) {
return false;
}
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
std::smatch match;
common_chat_msg result;
result.role = "assistant";
auto end = input.end();
auto it = input.begin();
std::vector<std::string> tool_names;
if (check_names) {
for (const auto & tool : tools) {
if (!tool.contains("type")) {
continue;
}
std::string type = tool.at("type");
if (type == "function") {
tool_names.push_back(tool["function"]["name"]);
} else if (type == "code_interpreter") {
tool_names.push_back("ipython");
}
}
}
while (it != end) {
std::sregex_iterator rend;
std::sregex_iterator rit(it, end, function_regex);
if (rit == rend) {
fprintf(stderr, "No more tool calls found\n");
result.content += std::string(it, end);
break;
}
auto name = rit->str(1);
if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
result.content += std::string(it, rit->suffix().first);
break;
}
result.content += std::string(it, rit->prefix().second);
it = rit->suffix().first;
json arguments;
if (!parse_json(it, end, arguments)) {
throw std::runtime_error("Failed to parse json tool call arguments");
}
if (!std::regex_search(it, end, match, close_regex)) {
throw std::runtime_error("Malformed input, missing closing pattern");
}
it = match.suffix().first;
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
}
return result;
}
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
auto content_end = input.find(prefix);
size_t tc_start = std::string::npos;
common_chat_msg result;
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) {
const auto & arguments = tool_call["arguments"];
result.tool_calls.push_back({
tool_call["name"],
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
};
if (content_end == std::string::npos) {
result.content = input;
} else {
tc_start = content_end + prefix.size() - rstrip_prefix;
result.content = input.substr(0, content_end);
auto tool_calls = json::parse(input.substr(tc_start));
process_tool_calls(tool_calls);
}
return result;
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages;
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json {
{"role", "system"},
{"content", existing_system + "\n" + system_prompt},
};
} else {
messages_with_system.insert(messages_with_system.begin(), json {
{"role", "system"},
{"content", system_prompt},
});
}
return messages_with_system;
}
class text_chat_parser : public common_chat_parser {
public:
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
return parse_final(input);
}
common_chat_msg parse_final(const std::string & input) override {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
};
class monolithic_chat_parser : public common_chat_parser {
std::string input_buffer_;
std::function<common_chat_msg(const std::string & input)> parse_final_;
public:
monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
input_buffer_ += input;
return std::nullopt;
}
common_chat_msg parse_final(const std::string & input) override {
input_buffer_ += input;
auto out = parse_final_(input_buffer_);
input_buffer_.clear();
return out;
}
};
static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
auto tool_call_schemas = json::array();
for (const auto & tool : params.tools) {
const auto & function = tool["function"];
auto tool_schema = json {
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}},
{"required", json::array({"name", "arguments"})},
};
if (function.contains("description")) {
tool_schema["description"] = function["description"];
}
if (params.parallel_tool_calls) {
tool_schema["properties"]["id"] = {
{"type", "string"},
{"minLength", 4},
};
tool_schema["required"].push_back("id");
}
tool_call_schemas.emplace_back(tool_schema);
}
const auto tool_call =
params.parallel_tool_calls
? json {
{"type", "object"},
{"properties", {
{"tool_calls", {
{"type", "array"},
{"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
{"anyOf", tool_call_schemas},
}},
{"minItems", 1},
}},
}},
{"required", json::array({"tool_calls"})},
}
: json {
{"type", "object"},
{"properties", {
{"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
{"anyOf", tool_call_schemas},
}},
}},
{"required", json::array({"tool_call"})},
};
const auto schema =
params.tool_choice != "required"
? json {
{"anyOf", json::array({
tool_call,
{
{"type", "object"},
{"properties", {
{"response", params.json_schema.is_null()
? json {{"type", "string"}}
: params.json_schema
},
}},
{"required", json::array({"response"})},
},
})}
}
: tool_call;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema);
}, grammar_options);
// TODO: add schema to system prompt.
auto tweaked_messages = add_system(
params.messages,
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({
tool_call["name"],
tool_call["arguments"].dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
data["tool_call"]["name"],
data["tool_call"]["arguments"].dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
const auto & response = data["response"];
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
});
return data;
}
static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : params.tools) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
// Important note: the model is probably trained to take a JSON stringified arguments value.
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
{"id", {
{"type", "string"},
// Nemo's template expects a 9-character alphanumeric ID.
{"pattern", "^[a-zA-Z0-9]{9}$"},
}},
}},
{"required", json::array({"name", "arguments", "id"})},
});
}
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!params.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
});
return data;
}
static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
builtin_tools.push_back("code_interpreter");
break;
}
}
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto has_python = false;
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
has_python = true;
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that c++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
}
}
if (has_python) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
}
if (params.tool_choice != "required" && eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true});
}
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
if (uses_python_tag) {
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ match[1].str(),
/* .id = */ "",
},
}
};
}
}
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
});
return data;
}
static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : params.tools) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}},
{"required", json::array({"name", "arguments", "id"})},
});
}
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!params.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
});
return data;
}
static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
auto has_python = false;
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true});
data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false});
}
}
}
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
// Note: if there's a python rule, it needs to come last.
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
if (has_python && params.tool_choice != "required") {
data.grammar_triggers.push_back({"python\n", /* .at_start = */ true});
data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false});
}
if (params.parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
} else {
builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : ""));
}
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
});
return data;
}
static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto has_python = false;
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
if (name == "python" || name == "ipython") {
has_python = true;
} else {
auto parameters = function["parameters"];
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
}
}
}
if (has_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
}
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ match[1].str(),
/* .id = */ "",
},
}
};
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
});
return data;
}
static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
common_chat_data data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : params.tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
}));
}
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
}
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {"assistant", input, {}};
}
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
json call;
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
result.tool_calls.push_back({
call["name"],
call["arguments"].dump(),
/* id= */ "",
});
rit = {it, end, middle_pattern};
if (rit != rend) {
it = rit->suffix().first;
} else {
rit = {it, end, end_pattern};
if (rit == rend) {
throw std::runtime_error("Malformed input, missing </tool_call>");
}
break;
}
}
return result;
} catch (const std::exception & e) {
return {"assistant", input, {}};
}
});
return data;
}
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
if (params.tools.is_null()) {
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.handler = std::make_unique<text_chat_parser>();
return data;
}
const auto & src = tmpl.source();
if (src.find("<tool_call>") != std::string::npos) {
return build_hermes_2_pro_tool_call_handler(tmpl, params);
}
if (src.find(">>>all") != std::string::npos) {
return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
}
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
}
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
}
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
// TODO: Command-R-Plus
// }
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return build_mistral_nemo_tool_call_handler(tmpl, params);
}
if (src.find(" functools[") != std::string::npos) {
return build_firefunction_v2_tool_call_handler(tmpl, params);
}
return build_generic_tool_call_handler(tmpl, params);
}

43
common/chat-handler.hpp Normal file
View file

@ -0,0 +1,43 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "common.h"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
struct common_chat_params {
json messages;
json tools;
json tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
};
class common_chat_parser {
public:
virtual ~common_chat_parser() = default;
virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
virtual common_chat_msg parse_final(const std::string & input) = 0;
};
struct common_chat_data {
std::string prompt;
std::string grammar;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> additional_stops;
std::unique_ptr<class common_chat_parser> handler;
};
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

View file

@ -20,7 +20,7 @@ namespace minja {
class chat_template { class chat_template {
public: public:
private: // private:
bool supports_tools_ = true; bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified. // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
@ -28,6 +28,7 @@ class chat_template {
bool requires_typed_content_ = false; bool requires_typed_content_ = false;
bool supports_system_role_ = true; bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false; bool supports_parallel_tool_calls_ = false;
bool supports_code_interpreter_ = false;
std::string source_; std::string source_;
std::string bos_token_; std::string bos_token_;
std::string eos_token_; std::string eos_token_;
@ -60,28 +61,7 @@ class chat_template {
}); });
supports_tools_ = source.find("tools") != std::string::npos; supports_tools_ = source.find("tools") != std::string::npos;
auto renders_string_arguments = requires_object_arguments_ =
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_raw_render({ try_raw_render({
{ {
{"role", "user"}, {"role", "user"},
@ -102,9 +82,27 @@ class chat_template {
}, },
})}, })},
} }
}, {}, false).find("{\"code\": \"print") != std::string::npos; }, {}, false).find("{\"code\": \"print") != std::string::npos
requires_object_arguments_ = renders_object_arguments; && try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
} }
}, {}, false).find("{\"code\": \"print") == std::string::npos;
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
supports_system_role_ = try_raw_render({ supports_system_role_ = try_raw_render({
@ -114,6 +112,8 @@ class chat_template {
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos; && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
supports_code_interpreter_ = source.find("code_interpreter") != std::string::npos;
} }
const std::string & source() const { return source_; } const std::string & source() const { return source_; }
@ -130,8 +130,45 @@ class chat_template {
bool adjust_inputs = true) const bool adjust_inputs = true) const
{ {
json actual_messages; json actual_messages;
json actual_tools;
// First, "fix" messages so they have a chance to be rendered correctly by the template auto has_code_interpreter = false;
for (const auto & tool : tools) {
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
has_code_interpreter = true;
break;
}
}
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) {
actual_tools = json::array();
for (const auto & tool : tools) {
if (tool.contains("type") && tool.at("type") == "code_interpreter") {
static const auto python_tool = json::parse(R"({
"type": "function",
"function": {
"name": "ipython",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter."
}
},
"required": ["code"]
}
}
})");
actual_tools.push_back(python_tool);
} else {
actual_tools.push_back(tool);
}
}
} else if (!tools.is_null()) {
actual_tools = tools;
}
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) { if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
actual_messages = json::array(); actual_messages = json::array();
@ -173,7 +210,12 @@ class chat_template {
if (tool_call["type"] == "function") { if (tool_call["type"] == "function") {
auto & function = tool_call.at("function"); auto & function = tool_call.at("function");
std::string arguments = function.at("arguments"); std::string arguments = function.at("arguments");
try {
function["arguments"] = json::parse(arguments); function["arguments"] = json::parse(arguments);
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
function["arguments"] = arguments;
}
} }
} }
} }
@ -242,6 +284,9 @@ class chat_template {
} else { } else {
actual_messages = messages; actual_messages = messages;
} }
// if (adjust_inputs) {
// fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
// }
auto context = minja::Context::make(json({ auto context = minja::Context::make(json({
{"messages", actual_messages}, {"messages", actual_messages},
@ -251,8 +296,12 @@ class chat_template {
})); }));
if (!tools.is_null()) { if (!tools.is_null()) {
auto tools_val = minja::Value(tools); auto tools_val = minja::Value(actual_tools);
context->set("tools", tools_val); context->set("tools", tools_val);
if (has_code_interpreter) {
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
context->set("builtin_tools", builtin_tools_val);
}
} }
if (!extra_context.is_null()) { if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) { for (auto & kv : extra_context.items()) {

View file

@ -109,6 +109,11 @@ enum common_conversation_mode {
COMMON_CONVERSATION_MODE_AUTO = 2, COMMON_CONVERSATION_MODE_AUTO = 2,
}; };
struct common_grammar_trigger {
std::string word;
bool at_start;
};
// sampling parameters // sampling parameters
struct common_params_sampling { struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -155,7 +160,7 @@ struct common_params_sampling {
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply

View file

@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
std::vector<const char *> trigger_words; std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size()); trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) { for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.c_str()); trigger_words.push_back(str.word.c_str());
} }
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,

View file

@ -1,747 +0,0 @@
#include "tool-call.h"
#include "json-schema-to-grammar.h"
#include <algorithm>
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using json = nlohmann::ordered_json;
static json normalize_tools(const json & tools) {
static const auto python_tool = json::parse(R"({
"type": "function",
"function": {
"name": "python",
"description": "Runs code in an Python interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the Python interpreter."
}
},
"required": ["code"]
}
}
})");
auto results = json::array();
for (const auto & tool : tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
results.push_back(python_tool);
} else if (tool["type"] == "function") {
results.push_back(tool);
} else {
continue;
}
}
return results;
}
std::string common_tool_call_style_name(common_tool_call_style style) {
switch (style) {
case COMMON_TOOL_CALL_STYLE_NONE:
return "None";
case COMMON_TOOL_CALL_STYLE_GENERIC:
return "Generic";
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
return "Llama-3.1";
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
return "Llama-3.2";
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
return "FunctionaryV3Llama3";
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
return "FunctionaryV3Llama3.1";
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
return "Hermes2Pro";
case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS:
return "CommandRPlus";
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
return "MistralNemo";
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
return "FirefunctionV2";
default:
return "Unknown";
}
}
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) {
const auto & src = chat_template.source();
if (src.find("<tool_call>") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
} else if (src.find(">>>all") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
} else if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (src.find("<|python_tag|>") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
} else {
return COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
}
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
} else if (src.find("[TOOL_CALLS]") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
} else if (src.find(" functools[") != std::string::npos) {
return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
} else {
return COMMON_TOOL_CALL_STYLE_GENERIC;
}
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
this->position = position - 1;
this->found_error = true;
return false;
}
bool null() override { return true; }
bool boolean(bool) override { return true; }
bool number_integer(number_integer_t) override { return true; }
bool number_unsigned(number_unsigned_t) override { return true; }
bool number_float(number_float_t, const string_t &) override { return true; }
bool string(string_t &) override { return true; }
bool binary(binary_t &) override { return true; }
bool start_object(std::size_t) override { return true; }
bool key(string_t &) override { return true; }
bool end_object() override { return true; }
bool start_array(std::size_t) override { return true; }
bool end_array() override { return true; }
};
json_error_locator err_loc;
json::sax_parse(it, end, &err_loc);
std::string::const_iterator temptative_end;
if (err_loc.found_error) {
temptative_end = it + err_loc.position;
} else {
temptative_end = end;
}
std::string json_sub {it, temptative_end};
try {
out = json::parse(json_sub);
it = temptative_end;
return true;
} catch (const std::exception &) {
return false;
}
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
std::smatch match;
common_chat_msg result;
result.role = "assistant";
auto end = input.end();
auto it = input.begin();
std::unordered_set<std::string> tool_names;
if (check_names) {
for (const auto & tool : tools) {
if (!tool.contains("type")) {
continue;
}
std::string type = tool.at("type");
if (type == "function") {
tool_names.insert(tool["function"]["name"]);
} else if (type == "code_interpreter") {
tool_names.insert("python");
}
}
}
while (it != end) {
std::sregex_iterator rend;
std::sregex_iterator rit(it, end, function_regex);
if (rit == rend) {
result.content += std::string(it, end);
break;
}
auto name = rit->str(1);
if (check_names && tool_names.find(name) == tool_names.end()) {
result.content += std::string(it, rit->suffix().first);
break;
}
result.content += std::string(it, rit->prefix().second);
it = rit->suffix().first;
json arguments;
if (!parse_json(it, end, arguments)) {
throw std::runtime_error("Failed to parse json tool call arguments");
}
if (!std::regex_search(it, end, match, close_regex)) {
throw std::runtime_error("Malformed input, missing closing pattern");
}
it = match.suffix().first;
result.tool_calls.push_back({name, arguments.dump(), /* id= */ ""});
}
return result;
}
static common_chat_msg parse_hermes_tool_calls(const std::string& input) {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {"assistant", input, {}};
}
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
json call;
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
result.tool_calls.push_back({
call["name"],
call["arguments"].dump(),
/* id= */ "",
});
rit = {it, end, middle_pattern};
if (rit != rend) {
it = rit->suffix().first;
} else {
rit = {it, end, end_pattern};
if (rit == rend) {
throw std::runtime_error("Malformed input, missing </tool_call>");
}
break;
}
}
return result;
} catch (const std::exception & e) {
return {"assistant", input, {}};
}
}
static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
if (allow_python_tag) {
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ (json {{"code", match[1].str()}}).dump(),
/* .id = */ "",
},
}
};
}
}
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ (json {{"code", match[1].str()}}).dump(),
/* .id = */ "",
},
}
};
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
}
static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
}
static common_chat_msg parse_generic_tool_calls(const std::string& input) {
json data = json::parse(input);
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({
tool_call["name"],
tool_call["arguments"].dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
data["tool_call"]["name"],
data["tool_call"]["arguments"].dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
const auto & response = data["response"];
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
}
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
auto content_end = input.find(prefix);
size_t tc_start = std::string::npos;
common_chat_msg result;
result.role = "assistant";
const auto process_tool_calls = [&](const json & tool_calls) {
for (const auto & tool_call : tool_calls) {
const auto & arguments = tool_call["arguments"];
result.tool_calls.push_back({
tool_call["name"],
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
};
if (content_end == std::string::npos) {
result.content = input;
} else {
tc_start = content_end + prefix.size() - rstrip_prefix;
result.content = input.substr(0, content_end);
auto tool_calls = json::parse(input.substr(tc_start));
process_tool_calls(tool_calls);
}
return result;
}
static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
switch (style) {
case COMMON_TOOL_CALL_STYLE_NONE:
return {"assistant", input, {}};
case COMMON_TOOL_CALL_STYLE_GENERIC:
return parse_generic_tool_calls(input);
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false);
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
return parse_functionary_v3_tool_calls(tools, input);
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
return parse_hermes_tool_calls(input);
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
return parse_mistral_nemo_tool_calls(input);
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
return parse_firefunction_v2_tool_calls(input);
default:
throw std::runtime_error("Unsupported tool call style");
}
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages;
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json {
{"role", "system"},
{"content", existing_system + "\n" + system_prompt},
};
} else {
messages_with_system.insert(messages_with_system.begin(), json {
{"role", "system"},
{"content", system_prompt},
});
}
return messages_with_system;
}
common_tool_call_handler common_tool_call_handler_init(
common_tool_call_style style,
const common_chat_template & tmpl,
bool allow_content,
const nlohmann::ordered_json & parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
const nlohmann::ordered_json & json_schema)
{
common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ true,
};
common_tool_call_handler handler;
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
switch (style) {
case COMMON_TOOL_CALL_STYLE_NONE:
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
break;
case COMMON_TOOL_CALL_STYLE_GENERIC: {
auto actual_tools = normalize_tools(tools);
auto tool_call_schemas = json::array();
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
auto tool_schema = json {
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}},
{"required", json::array({"name", "arguments"})},
};
if (function.contains("description")) {
tool_schema["description"] = function["description"];
}
if (parallel) {
tool_schema["properties"]["id"] = {
{"type", "string"},
{"minLength", 4},
};
tool_schema["required"].push_back("id");
}
tool_call_schemas.emplace_back(tool_schema);
}
const auto tool_call =
parallel
? json {
{"type", "object"},
{"properties", {
{"tool_calls", {
{"type", "array"},
{"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
{"anyOf", tool_call_schemas},
}},
{"minItems", 1},
}},
}},
{"required", json::array({"tool_calls"})},
}
: json {
{"type", "object"},
{"properties", {
{"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
{"anyOf", tool_call_schemas},
}},
}},
{"required", json::array({"tool_call"})},
};
const auto schema =
allow_content
? json {
{"anyOf", json::array({
tool_call,
{
{"type", "object"},
{"properties", {
{"response", json_schema.is_null()
? json {{"type", "string"}}
: json_schema
},
}},
{"required", json::array({"response"})},
},
})}
}
: tool_call;
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema);
}, grammar_options);
// TODO: add schema to system prompt.
auto tweaked_messages = add_system(
messages,
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break;
}
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
// Important note: the model is probably trained to take a JSON stringified arguments value.
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
{"id", {
{"type", "string"},
// Nemo's template expects a 9-character alphanumeric ID.
{"pattern", "^[a-zA-Z0-9]{9}$"},
}},
}},
{"required", json::array({"name", "arguments", "id"})},
});
}
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (allow_content) {
handler.grammar_triggers.push_back("[TOOL_CALLS]");
}
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break;
}
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function["name"]},
}},
{"arguments", function["parameters"]},
}},
{"required", json::array({"name", "arguments", "id"})},
});
}
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (allow_content) {
handler.grammar_triggers.push_back(" functools[");
}
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break;
}
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
for (const auto & tool : tools) {
if (!tool.contains("type")) {
continue;
}
if (tool["type"] == "code_interpreter") {
builtin_tools.push_back("code_interpreter");
break;
}
}
auto actual_tools = normalize_tools(tools);
auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "ipython" || builtin_tools.contains(name))) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_triggers.push_back("<|python_tag|>");
}
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content && !eagerly_match_any_json) {
handler.grammar_triggers.push_back("{\"name\": \"" + name + "\"");
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that c++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
handler.grammar_triggers.push_back("{\n\t\"name\": \"" + name + "\"");
handler.grammar_triggers.push_back("{\n \"name\": \"" + name + "\"");
handler.grammar_triggers.push_back("{\n \"name\": \"" + name + "\"");
handler.grammar_triggers.push_back("{\"type\": \"function\", \"name\": \"" + name + "\"");
}
}
}
if (allow_content && eagerly_match_any_json) {
handler.grammar_triggers.push_back("{\"");
handler.grammar_triggers.push_back("{\n\t\"");
handler.grammar_triggers.push_back("{\n \"");
handler.grammar_triggers.push_back("{\n \"");
}
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
handler.additional_stops.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools},
});
break;
}
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule));
if (allow_content) {
handler.grammar_triggers.push_back(name + "\n");
handler.grammar_triggers.push_back("\n>>>" + name + "\n");
}
}
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (parallel) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else {
builder.add_rule("root", first_rule);
}
}, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python" || name == "ipython") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_triggers.push_back("<|python_tag|>");
}
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
}
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_triggers.push_back("<function=");
}
}, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
// NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
}));
}
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_triggers.push_back("<tool_call>");
}
}, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break;
}
default:
throw std::runtime_error("Unsupported tool call style");
}
return handler;
}

View file

@ -1,44 +0,0 @@
#pragma once
#include "ggml.h"
#include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
enum common_tool_call_style {
COMMON_TOOL_CALL_STYLE_UNKNOWN,
COMMON_TOOL_CALL_STYLE_NONE,
COMMON_TOOL_CALL_STYLE_GENERIC,
COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
};
struct common_tool_call_handler {
std::string prompt;
std::string grammar;
std::vector<std::string> grammar_triggers;
std::vector<std::string> additional_stops;
};
std::string common_tool_call_style_name(common_tool_call_style style);
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);
common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
common_tool_call_handler common_tool_call_handler_init(
common_tool_call_style style,
const common_chat_template & tmpl,
bool allow_content,
const nlohmann::ordered_json & parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
const nlohmann::ordered_json & json_schema = {});

View file

@ -117,8 +117,7 @@ struct slot_params {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
json oaicompat_tools; std::shared_ptr<common_chat_parser> chat_parser;
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
json to_json() const { json to_json() const {
std::vector<std::string> samplers; std::vector<std::string> samplers;
@ -166,7 +165,7 @@ struct slot_params {
{"n_probs", sampling.n_probs}, {"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep}, {"min_keep", sampling.min_keep},
{"grammar", sampling.grammar}, {"grammar", sampling.grammar},
{"grammar_trigger_words", sampling.grammar_trigger_words}, // {"grammar_trigger_words", sampling.grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"samplers", samplers}, {"samplers", samplers},
{"speculative.n_max", speculative.n_max}, {"speculative.n_max", speculative.n_max},
@ -212,6 +211,7 @@ struct server_task {
static slot_params params_from_json_cmpl( static slot_params params_from_json_cmpl(
const llama_context * ctx, const llama_context * ctx,
const common_params & params_base, const common_params & params_base,
const common_chat_template * tmpl,
const json & data) { const json & data) {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
@ -322,6 +322,41 @@ struct server_task {
} }
} }
{
params.antiprompt.clear();
const auto stop = data.find("stop");
if (stop != data.end()) {
params.antiprompt = *stop;
}
}
if (tmpl && params_base.use_jinja) {
common_chat_params chat_params;
chat_params.messages = json_value(data, "messages", json::array());
chat_params.tools = json_value(data, "tools", json());
chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
chat_params.json_schema = json_value(data, "json_schema", json());
chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
chat_params.stream = json_value(data, "stream", false);
auto chat_data = common_chat_init(*tmpl, chat_params);
params.chat_parser = std::move(chat_data.handler);
params.sampling.grammar = chat_data.grammar;
for (const auto & stop : chat_data.additional_stops) {
params.antiprompt.push_back(stop);
}
for (const auto & trigger : chat_data.grammar_triggers) {
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
// process "json_schema" and "grammar" // process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
@ -336,13 +371,7 @@ struct server_task {
} else { } else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
} }
LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
if (data.contains("tools")) {
params.oaicompat_tools = data.at("tools");
}
if (data.contains("tool_call_style")) {
params.oaicompat_tool_call_style = data.at("tool_call_style");
}
{ {
params.sampling.logit_bias.clear(); params.sampling.logit_bias.clear();
@ -379,45 +408,11 @@ struct server_task {
} }
} }
auto to_string_vec = [](const json & j) {
std::vector<std::string> out;
if (j.is_array()) {
for (const auto & e : j) {
if (e.is_string()) {
out.push_back(e);
}
}
}
return out;
};
{
params.antiprompt.clear();
const auto stop = data.find("stop");
if (stop != data.end()) {
params.antiprompt = to_string_vec(*stop);
}
}
{
const auto grammar_trigger_words = data.find("grammar_trigger_words");
if (grammar_trigger_words != data.end()) {
for (const auto & word : to_string_vec(*grammar_trigger_words)) {
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
params.sampling.grammar_trigger_words.push_back(word);
}
}
}
{ {
const auto samplers = data.find("samplers"); const auto samplers = data.find("samplers");
if (samplers != data.end()) { if (samplers != data.end()) {
if (samplers->is_array()) { if (samplers->is_array()) {
params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false); params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
} else if (samplers->is_string()){ } else if (samplers->is_string()){
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>()); params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
} }
@ -592,8 +587,7 @@ struct server_task_result_cmpl_final : server_task_result {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
json oaicompat_tools; common_chat_msg oaicompat_chat_msg;
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
@ -688,18 +682,13 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() { json to_json_oaicompat_chat() {
std::string finish_reason = "length"; std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop"; finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
} }
json tool_calls; json tool_calls;
json message_content; if (!oaicompat_chat_msg.tool_calls.empty()) {
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
message_content = parsed_tool_calls.content;
tool_calls = json::array(); tool_calls = json::array();
for (const auto & tc : parsed_tool_calls.tool_calls) { for (const auto & tc : oaicompat_chat_msg.tool_calls) {
tool_calls.push_back({ tool_calls.push_back({
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
@ -709,18 +698,13 @@ struct server_task_result_cmpl_final : server_task_result {
{"id", tc.id.empty() ? json() : json(tc.id)}, {"id", tc.id.empty() ? json() : json(tc.id)},
}); });
} }
} else {
message_content = parsed_tool_calls.content;
}
} else {
message_content = content;
} }
json choice { json choice {
{"finish_reason", finish_reason}, {"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"message", json { {"message", json {
{"content", message_content}, {"content", oaicompat_chat_msg.content},
{"tool_calls", tool_calls}, {"tool_calls", tool_calls},
{"role", "assistant"}, {"role", "assistant"},
}}, }},
@ -812,6 +796,8 @@ struct server_task_result_cmpl_partial : server_task_result {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_chat_msg;
std::shared_ptr<common_chat_parser> chat_parser;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
@ -1234,6 +1220,8 @@ struct server_slot {
std::string stopping_word; std::string stopping_word;
std::shared_ptr<common_chat_parser> chat_parser;
// sampling // sampling
json json_schema; json json_schema;
@ -2260,6 +2248,10 @@ struct server_context {
} }
void send_partial_response(server_slot & slot, const completion_token_output & tkn) { void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
if (!opt_msg) {
return;
}
auto res = std::make_unique<server_task_result_cmpl_partial>(); auto res = std::make_unique<server_task_result_cmpl_partial>();
res->id = slot.id_task; res->id = slot.id_task;
@ -2275,6 +2267,7 @@ struct server_context {
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_msg = *opt_msg;
// populate res.probs_output // populate res.probs_output
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
@ -2315,8 +2308,7 @@ struct server_context {
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_tools = slot.params.oaicompat_tools; res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
res->oaicompat_tool_call_style = slot.params.oaicompat_tool_call_style;
// populate res.probs_output // populate res.probs_output
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
@ -3776,7 +3768,7 @@ int main(int argc, char ** argv) {
std::function<bool()> is_connection_closed, std::function<bool()> is_connection_closed,
httplib::Response & res, httplib::Response & res,
oaicompat_type oaicompat, oaicompat_type oaicompat,
common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) { const common_chat_template * tmpl) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) { if (ctx_server.params_base.embedding) {
@ -3788,6 +3780,20 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks; std::vector<server_task> tasks;
try { try {
fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
std::string prompt;
if (tmpl && ctx_server.params_base.use_jinja) {
auto chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_data(data, "messages", json::array()),
/* .tools = */ json_data(data, "tools", json()),
/
});
prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
} else {
prompt = data.at("prompt").get<std::string>();
}
task.params.chat_parser = common_chat_init()
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size()); tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@ -3800,12 +3806,14 @@ int main(int argc, char ** argv) {
task.params = server_task::params_from_json_cmpl( task.params = server_task::params_from_json_cmpl(
ctx_server.ctx, ctx_server.ctx,
ctx_server.params_base, ctx_server.params_base,
nullptr,
data); data);
task.id_selected_slot = json_value(data, "id_slot", -1); task.id_selected_slot = json_value(data, "id_slot", -1);
// OAI-compat // OAI-compat
task.params.oaicompat = oaicompat; task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_cmpl_id = completion_id;
task.params.chat_parser = common_chat_init()
task.params.oaicompat_tools = json_value(data, "tools", json()); task.params.oaicompat_tools = json_value(data, "tools", json());
task.params.oaicompat_tool_call_style = tool_call_style; task.params.oaicompat_tool_call_style = tool_call_style;
// oaicompat_model is already populated by params_from_json_cmpl // oaicompat_model is already populated by params_from_json_cmpl
@ -3983,18 +3991,16 @@ int main(int argc, char ** argv) {
auto body = json::parse(req.body); auto body = json::parse(req.body);
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
auto tool_call_style = common_tool_call_style_detect(chat_template); LOG_INF("Request: %s\n", body.dump(2).c_str());
LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str());
json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja); json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
data, data,
req.is_connection_closed, req.is_connection_closed,
res, res,
OAICOMPAT_TYPE_CHAT, OAICOMPAT_TYPE_CHAT);
tool_call_style);
}; };
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {

View file

@ -17,8 +17,8 @@
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
#include "minja.hpp" #include "minja.hpp"
#include "chat-handler.hpp"
#include "chat-template.hpp" #include "chat-template.hpp"
#include "tool-call.h"
#include <random> #include <random>
#include <sstream> #include <sstream>
@ -581,24 +581,18 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
const common_chat_template & tmpl, const common_chat_template & tmpl,
common_tool_call_style tool_call_style,
bool use_jinja) bool use_jinja)
{ {
json llama_params; json llama_params;
auto tools = json_value(body, "tools", json()); auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false); auto stream = json_value(body, "stream", false);
if (has_tools) { if (tools.is_array() && !tools.empty()) {
if (stream) { if (stream) {
throw std::runtime_error("Cannot use tools with stream"); throw std::runtime_error("Cannot use tools with stream");
} }
if (use_jinja) { if (!use_jinja) {
if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
}
} else {
throw std::runtime_error("tools param requires --jinja flag"); throw std::runtime_error("tools param requires --jinja flag");
} }
} }
@ -627,32 +621,16 @@ static json oaicompat_completion_params_parse(
// Apply chat template to the list of messages // Apply chat template to the list of messages
if (use_jinja) { if (use_jinja) {
bool allow_content = tool_choice != "required";
if (tool_choice != "none" && has_tools) {
llama_params["tools"] = tools; llama_params["tools"] = tools;
llama_params["tool_call_style"] = tool_call_style; auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json(); throw std::runtime_error("Invalid tool_choice: " + tool_choice);
llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = common_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stops) {
llama_params["stop"].push_back(stop);
} }
if (!handler.grammar_triggers.empty()) { llama_params["tool_choice"] = tool_choice;
llama_params["grammar_trigger_words"] = handler.grammar_triggers; llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
} if (tool_choice != "none" && llama_params.contains("grammar")) {
if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools."); throw std::runtime_error("Cannot use custom grammar constraints with tools.");
} }
llama_params["grammar"] = handler.grammar;
}
} else {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
}
} else { } else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages")); llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
} }

View file

@ -133,7 +133,7 @@ llama_target_and_test(test-chat-template.cpp)
# llama_target_and_test(test-opt.cpp) # SLOW # llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-gguf.cpp) llama_target_and_test(test-gguf.cpp)
llama_target_and_test(test-backend-ops.cpp) llama_target_and_test(test-backend-ops.cpp)
llama_target_and_test(test-tool-call.cpp) llama_target_and_test(test-chat-handler.cpp)
llama_target_and_test(test-model-load-cancel.cpp LABEL "model") llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
llama_target_and_test(test-autorelease.cpp LABEL "model") llama_target_and_test(test-autorelease.cpp LABEL "model")

View file

@ -1,4 +1,4 @@
#include "tool-call.h" #include "chat-handler.hpp"
#include "llama-grammar.h" #include "llama-grammar.h"
#include "unicode.h" #include "unicode.h"
@ -20,6 +20,7 @@ static void assert_equals(const T & expected, const T & actual) {
} }
static std::string read_file(const std::string &path) { static std::string read_file(const std::string &path) {
std::cout << "# Reading: " << path << std::endl << std::flush;
std::ifstream fs(path, std::ios_base::binary); std::ifstream fs(path, std::ios_base::binary);
if (!fs.is_open()) { if (!fs.is_open()) {
fs = std::ifstream("../" + path, std::ios_base::binary); fs = std::ifstream("../" + path, std::ios_base::binary);
@ -76,10 +77,16 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
assert_equals(expected_content, result.content); assert_equals(expected_content, result.content);
auto tool_calls = json::array(); auto tool_calls = json::array();
for (const auto & tc : result.tool_calls) { for (const auto & tc : result.tool_calls) {
auto arguments = tc.arguments;
try {
arguments = dump(json::parse(arguments));
} catch (const std::exception & e) {
// ignore
}
auto tool_call = json { auto tool_call = json {
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"arguments", dump(json::parse(tc.arguments))}, {"arguments", arguments},
{"name", tc.name}, {"name", tc.name},
}}, }},
}; };
@ -94,8 +101,7 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
assert_equals(expected, actual); assert_equals(expected, actual);
} }
const json tools = json::parse(R"([ const auto special_function_tool = json::parse(R"({
{
"type": "function", "type": "function",
"function": { "function": {
"name": "special_function", "name": "special_function",
@ -111,25 +117,28 @@ const json tools = json::parse(R"([
"required": ["arg1"] "required": ["arg1"]
} }
} }
}, })");
{ const auto python_tool = json::parse(R"({
"type": "function", "type": "function",
"function": { "function": {
"name": "python", "name": "python",
"description": "a python interpreter", "description": "an ipython interpreter",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"code": { "code": {
"type": "string", "type": "string",
"description": "The code." "description": "Python code to execute."
} }
}, },
"required": ["code"] "required": ["code"]
} }
} }
} })");
])"); const auto code_interpreter_tool = json::parse(R"({
"type": "code_interpreter"
})");
const json tools = {special_function_tool, code_interpreter_tool};
static void test_parsing() { static void test_parsing() {
json request = { json request = {
@ -226,9 +235,7 @@ static void test_parsing() {
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"name", "python"}, {"name", "python"},
{"arguments", dump({ {"arguments", "this could be anything"},
{"code", "this could be anything"}
})}
}} }}
}}); }});
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
@ -238,7 +245,7 @@ static void test_parsing() {
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"name", "python"}, {"name", "python"},
{"arguments", dump({{"code", ""}})} {"arguments", ""},
}} }}
}}); }});
auto special_function_call = json { auto special_function_call = json {
@ -332,6 +339,8 @@ static void test_tool_call_style_detection() {
} }
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) { static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
@ -354,9 +363,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
return delta; return delta;
} }
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
auto tool_call_style = common_tool_call_style_detect(tmpl); auto tool_call_style = common_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls"); auto & tool_calls = tool_calling_message.at("tool_calls");
@ -404,24 +411,77 @@ static void test_grammars() {
auto tool_call_message_with_id = json::parse(tool_call_message.dump()); auto tool_call_message_with_id = json::parse(tool_call_message.dump());
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789"; tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools, auto python_tool_call_message = json {
/* skip_grammar_test= */ true); {"role", "assistant"},
test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools); {"content", ""},
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools); {"tool_calls", json {{
test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools); {"type", "function"},
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); {"function", {
test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); {"name", "python"},
test_template("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); {"arguments", "print('hey')"}
test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); }},
test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); }}}
test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "<s>", "</s>", { "<|eot_id|>" }, tool_call_message, tools); };
test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "<s>", "</s>", { "<end_of_turn>" }, tool_call_message_with_id, tools);
test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "<s>", "</s>", { "<|end|>" }, tool_call_message_with_id, tools); // {
// const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
// test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
// assert_equals(tmpl.requires_object_arguments_, true);
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
}
} }
int main() { int main() {
test_tool_call_style_detection(); // test_tool_call_style_detection();
test_parsing(); // test_parsing();
test_grammars(); test_grammars();
std::cout << "\n[tool-call] All tests passed!" << std::endl; std::cout << "\n[tool-call] All tests passed!" << std::endl;