split chat handler vs. parser around enum again

ochafik 2025-01-30 04:11:05 +00:00
parent 590c97931a
commit f8e14bffc3
5 changed files with 501 additions and 427 deletions
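In short: common_chat_params previously carried a per-template std::function parser closure, which could not be serialized across the server's task boundary. This commit replaces it with a plain common_chat_format enum plus a single common_chat_parse(input, format) dispatcher, so the format can travel through JSON and the raw output is parsed only once generation completes. A minimal usage sketch of the resulting API (identifiers are from this diff; tmpl, tools and generated_text are assumed to exist and are not part of the commit):

common_chat_inputs inputs;
inputs.messages = json::array({
    {{"role", "user"}, {"content", "What is 2+2?"}},
});
inputs.tools = tools; // assumed: a JSON array of function tool definitions

common_chat_params params = common_chat_params_init(tmpl, inputs);
// params.format is now a serializable enum rather than a closure.

// ... run generation constrained by params.prompt / params.grammar ...

common_chat_msg msg = common_chat_parse(generated_text, params.format);
for (const auto & call : msg.tool_calls) {
    // call.name, call.arguments (a JSON-encoded string), call.id
}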


@@ -4,6 +4,23 @@
#include "log.h"
#include "minja.hpp"
std::string common_chat_format_name(common_chat_format format) {
switch (format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
default:
throw std::runtime_error("Unknown chat format");
}
}
const common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ false,
@@ -55,25 +72,21 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
}
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
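// Illustrative example (not part of this commit): with
//   function_regex = R"(<function=(\w+)>)" and close_regex = R"(</function>)",
// the input `I'll check. <function=special_function>{"arg1": 1}</function>`
// yields content "I'll check. " plus one tool call
// { name: "special_function", arguments: {"arg1": 1} }.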
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
static common_chat_msg parse_json_tool_calls(
const std::string& input,
const std::optional<std::regex> & trigger_opt,
const std::regex & function_regex,
const std::regex & close_regex) {
std::smatch match;
common_chat_msg result;
result.role = "assistant";
std::vector<std::string> tool_names;
if (check_names) {
for (const auto & tool : tools) {
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
continue;
}
tool_names.push_back(tool["function"]["name"]);
}
}
auto end = input.end();
auto it = input.begin();
@@ -96,24 +109,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
break;
}
auto name = rit->str(1);
if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
result.content += std::string(it, rit->suffix().first);
it = rit->suffix().first;
continue;
}
result.content += std::string(it, rit->prefix().second);
it = rit->suffix().first;
json arguments;
if (!parse_json(it, end, arguments)) {
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""});
break;
}
throw std::runtime_error("Failed to parse json tool call arguments");
}
if (!std::regex_search(it, end, match, close_regex)) {
@@ -162,15 +162,7 @@ static void foreach_function(const json & tools, const std::function<void(const
}
}
static common_chat_msg no_op_text_parser(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
@@ -252,35 +244,35 @@ static common_chat_params common_chat_params_init_generic_tool_call(const common
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "generic tool calls";
data.parser = [&](const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({
tool_call["name"],
tool_call["arguments"].dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
data["tool_call"]["name"],
data["tool_call"]["arguments"].dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
const auto & response = data["response"];
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
};
data.format = COMMON_CHAT_FORMAT_GENERIC;
return data;
}
static common_chat_msg common_chat_parse_generic(const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
result.role = "assistant";
if (data.contains("tool_calls")) {
for (const auto & tool_call : data["tool_calls"]) {
result.tool_calls.push_back({
tool_call["name"],
tool_call["arguments"].dump(),
tool_call.contains("id") ? tool_call["id"] : "",
});
}
} else if (data.contains("tool_call")) {
result.tool_calls.push_back({
data["tool_call"]["name"],
data["tool_call"]["arguments"].dump(),
/* id= */ "",
});
} else if (data.contains("response")) {
const auto & response = data["response"];
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
}
return result;
}
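// Illustrative inputs for this parser, taken from the tests in this commit:
//   {"response": "Hello, world!"} -> content-only message
//   {"tool_calls": [{"name": "special_function", "arguments": {"arg1": 1}, "id": "123456789"}]}
//     -> message with one tool call (id preserved)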
static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -318,12 +310,12 @@ static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const c
}, grammar_options);
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "mistral nemo tool calls";
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
};
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
return data;
}
static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
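// Illustrative input (from the tests in this commit):
//   [TOOL_CALLS][{"name": "special_function", "arguments": {"arg1": 1}, "id": "123456789"}]
// parses into a single tool call that keeps the provided id.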
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
@@ -379,7 +371,6 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
return true;
};
auto has_function = false;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
@@ -411,45 +402,48 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
if (allow_python_tag_builtin_tools && !builtin_tools.empty()) {
std::smatch match;
if (std::regex_match(input, match, builtin_call_regex)) {
auto name = match[1].str();
auto raw_args = match[2].str();
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
auto it_eq = raw_args.find('=');
auto arg_name = raw_args.substr(0, it_eq);
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
},
},
};
}
}
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
};
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
: COMMON_CHAT_FORMAT_LLAMA_3_X;
return data;
}
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
// TODO: tighten & simplify the parser, don't accept leading text context.
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
if (with_builtin_tools) {
std::smatch match;
if (std::regex_match(input, match, builtin_call_regex)) {
auto name = match[1].str();
auto raw_args = match[2].str();
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
auto it_eq = raw_args.find('=');
auto arg_name = raw_args.substr(0, it_eq);
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
},
},
};
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
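// Illustrative inputs (from the tests in this commit): with builtin tools,
//   <|python_tag|>python.call(code="print('hey')")
// becomes a "python" tool call with arguments {"code": "print('hey')"}; otherwise
//   {"name": "special_function", "parameters": {"arg1": 1}}
// is matched by function_regex above and parsed as a regular JSON tool call.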
static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -466,19 +460,23 @@ static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const co
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "deepseek r1 tool calls";
data.parser = [inputs](const std::string & input) {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
};
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
return data;
}
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
}
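// Illustrative shape matched by the regexes above (not a verbatim test case):
//   <tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function
//   ```json
//   {"arg1": 1}
//   ```<tool▁call▁end>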
static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
fprintf(stderr, "%s\n", __func__);
common_chat_params data;
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
}, /* adjust_inputs= */ false);
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -508,26 +506,22 @@ static common_chat_params common_chat_params_init_firefunction_v2_tool_call(cons
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.parser = [](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
};
data.format = "firefunction v2 tool calls";
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
} else {
data.parser = no_op_text_parser;
data.format = "firefunction v2 text-only";
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
}
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
}, /* adjust_inputs= */ false);
return data;
}
static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -552,26 +546,52 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
}
}, grammar_options);
data.format = "functionary v3.2 tool calls";
} else {
data.format = "functionary v3.2 content-only";
}
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.parser = [inputs](const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
if (res.content.find("all\n") == 0) {
res.content = res.content.substr(4);
}
return res;
};
return data;
}
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
auto expected_it = expected.begin();
auto tmp_it = it;
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
++tmp_it;
++expected_it;
}
if (expected_it == expected.end()) {
it = tmp_it;
return true;
}
return false;
}
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
std::string content;
auto it = input.begin();
const auto end = input.end();
if (consume(it, end, "all\n")) {
std::smatch match;
if (std::regex_search(it, end, match, function_regex)) {
auto fun_it = match.prefix().second;
content = std::string(it, fun_it);
it = fun_it;
} else {
common_chat_msg res;
res.role = "assistant";
res.content = std::string(it, end);
return res;
}
}
// TODO: tighten & simplify.
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
res.content = content;
return res;
}
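// Illustrative inputs (per the ">>>" grammar comment above and the tests):
//   all\nHello, world!            -> content-only ("all" is the no-tool channel)
//   special_function\n{"arg1": 1} -> one tool call to special_function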
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
@@ -620,33 +640,35 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_too
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "functionary v3.1 llama 3.1 tool calls";
data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
/* .id = */ "",
},
}
};
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
};
// TODO: if (has_raw_python)
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
return data;
}
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ (json {{"code", code}}).dump(),
/* .id = */ "",
},
}
};
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
// TODO: tighten & simplify.
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
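// Illustrative inputs (from the tests in this commit):
//   <function=special_function>{"arg1": 1}</function> -> one tool call
//   <|python_tag|>print('hey')  -> a "python" call with {"code": "print('hey')"}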
static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = inputs.tool_choice != "required";
@@ -672,69 +694,68 @@ static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const c
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "hermes 2 pro tool calls";
data.parser = [&](const std::string & input) -> common_chat_msg {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
return data;
}
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
json call;
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call["arguments"];
result.tool_calls.push_back({
call["name"],
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
/* id= */ "",
});
rit = {it, end, middle_pattern};
if (rit != rend) {
it = rit->suffix().first;
} else {
rit = {it, end, end_pattern};
if (rit == rend) {
throw std::runtime_error("Malformed input, missing </tool_call>");
}
break;
}
}
return result;
} catch (const std::exception & e) {
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
};
return data;
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
json call;
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call["arguments"];
result.tool_calls.push_back({
call["name"],
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
/* id= */ "",
});
rit = {it, end, middle_pattern};
if (rit != rend) {
it = rit->suffix().first;
} else {
rit = {it, end, end_pattern};
if (rit == rend) {
throw std::runtime_error("Malformed input, missing </tool_call>");
}
break;
}
}
return result;
} catch (const std::exception & e) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
}
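// Illustrative input (from the tests in this commit):
//   <tool_call>
//   {"name": "special_function", "arguments": {"arg1": 1}}
//   </tool_call>
// parses into one tool call; text with no <tool_call> tag is returned verbatim as content.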
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = "content-only";
data.parser = no_op_text_parser;
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
data.grammar_lazy = false;
if (!inputs.json_schema.is_null()) {
if (!inputs.grammar.empty()) {
@@ -748,7 +769,7 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
}
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
if (has_tools && !inputs.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
@@ -760,30 +781,64 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
}
if (src.find(" functools[") != std::string::npos) {
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
return common_chat_params_init_firefunction_v2(tmpl, inputs);
}
if (has_tools) {
if (!has_tools) {
return common_chat_params_init_without_tools(tmpl, inputs);
}
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
}
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
}
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
}
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
return common_chat_params_init_deepseek_r1(tmpl, inputs);
}
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}
return common_chat_params_init_generic_tool_call(tmpl, inputs);
return common_chat_params_init_generic(tmpl, inputs);
}
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
switch (format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
return common_chat_parse_content_only(input);
case COMMON_CHAT_FORMAT_GENERIC:
return common_chat_parse_generic(input);
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
return common_chat_parse_mistral_nemo(input);
case COMMON_CHAT_FORMAT_LLAMA_3_X:
return common_chat_parse_llama_3_1(input);
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
return common_chat_parse_deepseek_r1(input);
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
return common_chat_parse_functionary_v3_2(input);
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
return common_chat_parse_functionary_v3_1_llama_3_1(input);
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
return common_chat_parse_hermes_2_pro(input);
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}
}


@@ -21,16 +21,30 @@ struct common_chat_inputs {
bool add_generation_prompt = true;
};
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_params {
json prompt;
std::string grammar;
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> additional_stops;
// std::unique_ptr<class common_chat_parser> parser;
common_chat_parser parser;
std::string format; // For debugging and testing.
bool grammar_lazy = false;
std::vector<std::string> additional_stops;
};
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format);


@@ -117,7 +117,7 @@ struct slot_params {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_parser chat_parser;
common_chat_format oaicompat_chat_format;
json to_json() const {
std::vector<std::string> samplers;
@@ -321,27 +321,51 @@ struct server_task {
}
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
}
{
params.antiprompt.clear();
const auto stop = data.find("stop");
if (stop != data.end()) {
params.antiprompt = *stop;
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<u_int32_t>());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
}
}
if (!params_base.use_jinja) {
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
try {
auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
{
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
common_grammar_trigger trigger;
trigger.word = t.at("word");
trigger.at_start = t.at("at_start");
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
}
@@ -380,6 +404,19 @@ struct server_task {
}
}
{
params.antiprompt.clear();
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
}
}
}
{
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
@@ -533,7 +570,7 @@ struct completion_token_output {
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
common_chat_msg message;
std::string content;
llama_tokens tokens;
bool stream;
@@ -559,6 +596,7 @@ struct server_task_result_cmpl_final : server_task_result {
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
virtual int get_index() override {
return index;
@@ -584,7 +622,7 @@
json to_json_non_oaicompat() {
json res = json {
{"index", index},
{"content", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"tokens", stream ? llama_tokens {} : tokens},
{"id_slot", id_slot},
{"stop", true},
@@ -621,7 +659,7 @@
json res = json {
{"choices", json::array({
json{
{"text", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"index", index},
{"logprobs", logprobs},
{"finish_reason", finish_reason},
@@ -652,8 +690,12 @@
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg message;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
message = common_chat_parse(content, oaicompat_chat_format);
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
} else {
message.content = content;
}
json tool_calls;
@@ -1189,7 +1231,6 @@ struct server_slot {
std::string stopping_word;
// sampling
json json_schema;
@@ -1197,7 +1238,7 @@
llama_token sampled;
common_chat_parser chat_parser;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
// stats
size_t n_sent_text = 0; // number of sent text characters
@@ -2282,10 +2323,11 @@ struct server_context {
res->id_slot = slot.id;
res->index = slot.index;
res->tokens = slot.generated_tokens;
res->content = std::move(slot.generated_text);
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->response_fields = slot.params.response_fields;
res->response_fields = std::move(slot.params.response_fields);
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
@@ -2296,21 +2338,12 @@
res->stop = slot.stop;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
if (slot.params.chat_parser) {
LOG_DBG("Raw chat output: %s\n", slot.generated_text.c_str());
res->message = slot.params.chat_parser(slot.generated_text);
} else {
res->message = {
/* .role = */ "assistant",
/* .content = */ std::move(slot.generated_text),
/* .tool_calls = */ {}
};
}
res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -3773,8 +3806,7 @@ int main(int argc, char ** argv) {
json & data,
std::function<bool()> is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat,
const common_chat_template * tmpl = nullptr) {
oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
@@ -3786,40 +3818,7 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks;
try {
common_chat_params chat_params;
bool add_special = false;
if (tmpl && ctx_server.params_base.use_jinja) {
chat_params = common_chat_params_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()),
/* .tools = */ json_value(data, "tools", json()),
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
/* .json_schema = */ json_value(data, "json_schema", json()),
/* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
/* .stream = */ json_value(data, "stream", false),
/* .grammar = */ json_value(data, "grammar", std::string("")),
});
LOG_INF("Chat format: %s\n", chat_params.format.c_str());
LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
if (data.contains("grammar")) {
if (!chat_params.grammar.empty()) {
throw std::runtime_error("Cannot provide grammar and tools");
}
chat_params.grammar = data.at("grammar");
}
// TODO: move inside minja:chat_template?
add_special = tmpl->source().find("eos_token") == std::string::npos &&
tmpl->source().find("bos_token") == std::string::npos;
} else {
add_special = true;
chat_params.prompt = data.at("prompt");
if (data.contains("grammar")) {
chat_params.grammar = data.at("grammar");
} else if (data.contains("json_schema")) {
chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
}
}
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
@@ -3837,25 +3836,6 @@ int main(int argc, char ** argv) {
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id;
// Grammar & tool-calls
task.params.sampling.grammar = chat_params.grammar;
task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
for (const auto & trigger : chat_params.grammar_triggers) {
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
task.params.sampling.grammar_trigger_words.push_back(trigger);
}
task.params.antiprompt = chat_params.additional_stops;
task.params.chat_parser = chat_params.parser;
if (task.params.sampling.grammar_lazy) {
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
}
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(task);
@@ -4039,8 +4019,7 @@ int main(int argc, char ** argv) {
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT,
&chat_template);
OAICOMPAT_TYPE_CHAT);
};
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
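Taken together, the server-side flow after this change is (condensed from the hunks above; step numbering added for this summary):

// 1) oaicompat_completion_params_parse() runs common_chat_params_init() and
//    stores the enum in the task JSON:
//      llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format);
// 2) params_from_json_cmpl() reads it back into the slot params:
//      params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<u_int32_t>());
// 3) when the slot finishes, to_json_oaicompat_chat() parses the raw text:
//      message = common_chat_parse(content, oaicompat_chat_format);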


@@ -630,11 +630,34 @@ static json oaicompat_completion_params_parse(
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
llama_params["tool_choice"] = tool_choice;
llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
if (tool_choice != "none" && llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
common_chat_inputs inputs;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json::object());
auto chat_params = common_chat_params_init(tmpl, inputs);
llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
}
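The lazy-grammar triggers travel the same way, as plain JSON between this OAI-compat layer and params_from_json_cmpl(); a sketch of one trigger as built above (fields from this diff, the word value is illustrative):

json trigger = {
    {"word", "[TOOL_CALLS]"},
    {"at_start", true},
};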


@@ -170,9 +170,9 @@ const json tools = {special_function_tool, python_tool};
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
struct delta_data {
std::string delta;
std::string grammar;
common_chat_parser parser;
std::string delta;
std::string grammar;
common_chat_format format;
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
@@ -212,7 +212,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
break;
}
}
return {delta, params_full.grammar, params_full.parser};
return {delta, params_full.grammar, params_full.format};
}
/*
@@ -235,7 +235,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
}
if (!skip_parser_test) {
const auto msg = data.parser(data.delta);
const auto msg = common_chat_parse(data.delta, data.format);
assert_msg_equals(expected_msg, msg);
}
@@ -257,10 +257,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
}
}
static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
return common_chat_params_init(tmpl, params).format;
}
static void test_template_output_parsers() {
auto text_message = json {
{"role", "assistant"},
@@ -315,117 +311,124 @@ static void test_template_output_parsers() {
inputs_tools.tools = json::array();
inputs_tools.tools.push_back(special_function_tool);
{
const common_chat_template tmpl(read_file(
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end_of_turn>" };
common_chat_inputs inputs_tools_builtin = inputs_no_tools;
inputs_tools_builtin.tools = json::array();
inputs_tools_builtin.tools.push_back(python_tool);
assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
// {
// const common_chat_template tmpl(read_file(
// "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "<end_of_turn>" };
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
"{\n"
" \"response\": \"Hello, world!\"\n"
"}"));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}");
}
{
const common_chat_template tmpl(read_file(
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "</s>" };
// assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(common_chat_template(read_file(
// "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
// // Generic tool calls doesn't generate / parse content-only messages symmetrically.
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
}
{
const common_chat_template tmpl(read_file(
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
// assert_msg_equals(msg_from_json(text_message), common_chat_parse(
// "{\n"
// " \"response\": \"Hello, world!\"\n"
// "}",
// common_chat_params_init(tmpl, inputs_tools).format));
// test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
// "{\n"
// " \"tool_calls\": [\n"
// " {\n"
// " \"name\": \"special_function\",\n"
// " \"arguments\": {\n"
// " \"arg1\": 1\n"
// " },\n"
// " \"id\": \"123456789\"\n"
// " }\n"
// " ]\n"
// "}");
// }
// {
// const common_chat_template tmpl(read_file(
// "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "</s>" };
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"</tool_call>");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
"</tool_call>");
}
{
const common_chat_template tmpl(read_file(
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
// test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
// "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
// /* skip_grammar_test= */ true);
// }
// {
// const common_chat_template tmpl(read_file(
// "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
// "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
// "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
"<|python_tag|>python.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file(
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
// test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, tool_call_message, tools,
// "<tool_call>\n"
// "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
// "</tool_call>");
// test_template(tmpl, end_tokens, python_tool_call_message, tools,
// "<tool_call>\n"
// "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
// "</tool_call>");
// }
// {
// const common_chat_template tmpl(read_file(
// "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(tmpl, inputs_tools_builtin).format);
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(common_chat_template(read_file(
// "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools_builtin).format);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file(
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
// // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
// "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
// test_template(tmpl, end_tokens, python_tool_call_message, tools,
// "<|python_tag|>python.call(code=\"print('hey')\")");
// test_template(tmpl, end_tokens, tool_call_message, tools,
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
// }
// {
// const common_chat_template tmpl(read_file(
// "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
// test_template(tmpl, end_tokens, text_message, tools,
// "Hello, world!", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, tool_call_message, tools,
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
// }
// {
// const common_chat_template tmpl(read_file(
// "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
// assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format);
// test_template(tmpl, end_tokens, text_message, tools,
// "Hello, world!", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, tool_call_message, tools,
// "<function=special_function>{\"arg1\": 1}</function>");
// }
{
const common_chat_template tmpl(read_file(
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools,
test_template(tmpl, end_tokens, text_message, {},
"all\n"
"Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -437,7 +440,7 @@ static void test_template_output_parsers() {
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eot_id|>" };
assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
@@ -449,7 +452,7 @@ static void test_template_output_parsers() {
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end▁of▁sentence>" };
assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools,
"Hello, world!", /* skip_grammar_test= */ true);
@@ -480,7 +483,7 @@ int main(int argc, char **argv) {
common_chat_template tmpl(read_file(path), "", "");
auto parts = string_split(path, "/");
auto name = parts[parts.size() - 1];
std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) << " |\n";
}
}
else