split chat handler vs. parser around enum again
parent 590c97931a
commit f8e14bffc3
5 changed files with 501 additions and 427 deletions
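The change splits tool-call handling into two halves: `common_chat_params_init` (the handler) still renders the prompt and grammar from the chat template, but instead of capturing a per-template `std::function` parser in the returned params it now records a plain `common_chat_format` enum, and the model output is parsed later by the free function `common_chat_parse(input, format)`. Because the enum is plain data, it can round-trip through the server's task JSON, which a captured closure could not. A minimal sketch of the resulting call flow, using the identifiers introduced in this commit (the generation step and surrounding glue are illustrative assumptions, not code from the diff):

    // Sketch only: tmpl is a common_chat_template; run_generation() stands in
    // for the server's actual sampling loop.
    common_chat_inputs inputs;
    inputs.messages = json::parse(R"([{"role": "user", "content": "What is 2+2?"}])");
    inputs.tools = json(); // no tools in this example

    common_chat_params params = common_chat_params_init(tmpl, inputs);
    // params.prompt, params.grammar and params.grammar_triggers drive generation;
    // params.format is the plain enum that later selects the parser.
    std::string output = run_generation(params); // hypothetical
    common_chat_msg msg = common_chat_parse(output, params.format);
    // msg.content holds plain text; msg.tool_calls holds parsed {name, arguments, id}.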
common/chat.cpp (489 changes)
@@ -4,6 +4,23 @@
 #include "log.h"
 #include "minja.hpp"
 
+std::string common_chat_format_name(common_chat_format format) {
+    switch (format) {
+        case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
+        case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
+        case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
+        case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
+        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
+        case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
+        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
+        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
+        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
+        case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+        default:
+            throw std::runtime_error("Unknown chat format");
+    }
+}
+
 const common_grammar_options grammar_options {
     /* .dotall = */ false,
     /* .compact_spaces = */ false,
@@ -55,25 +72,21 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
     }
 }
 
 
 /**
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
+static common_chat_msg parse_json_tool_calls(
+    const std::string& input,
+    const std::optional<std::regex> & trigger_opt,
+    const std::regex & function_regex,
+    const std::regex & close_regex) {
     std::smatch match;
 
     common_chat_msg result;
     result.role = "assistant";
 
-    std::vector<std::string> tool_names;
-    if (check_names) {
-        for (const auto & tool : tools) {
-            if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
-                continue;
-            }
-            tool_names.push_back(tool["function"]["name"]);
-        }
-    }
-
     auto end = input.end();
     auto it = input.begin();
@@ -96,24 +109,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
             break;
         }
         auto name = rit->str(1);
-        if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
-            fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
-            result.content += std::string(it, rit->suffix().first);
-            it = rit->suffix().first;
-            continue;
-        }
-
         result.content += std::string(it, rit->prefix().second);
         it = rit->suffix().first;
 
         json arguments;
         if (!parse_json(it, end, arguments)) {
-            if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
-                std::string src(it, end);
-                result.tool_calls.push_back({name, src, /* id= */ ""});
-                break;
-            }
             throw std::runtime_error("Failed to parse json tool call arguments");
         }
         if (!std::regex_search(it, end, match, close_regex)) {
@@ -162,15 +162,7 @@ static void foreach_function(const json & tools, const std::function<void(const
     }
 }
 
-static common_chat_msg no_op_text_parser(const std::string & input) {
-    return {
-        /* .role = */ "assistant",
-        /* .content = */ input,
-        /* .tool_calls = */ {},
-    };
-}
-
-static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
 
     auto tool_call_schemas = json::array();
@@ -252,35 +244,35 @@ static common_chat_params common_chat_params_init_generic_tool_call(const common
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
     data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "generic tool calls";
-    data.parser = [&](const std::string & input) {
-        json data = json::parse(input);
-        common_chat_msg result;
-        result.role = "assistant";
-        if (data.contains("tool_calls")) {
-            for (const auto & tool_call : data["tool_calls"]) {
-                result.tool_calls.push_back({
-                    tool_call["name"],
-                    tool_call["arguments"].dump(),
-                    tool_call.contains("id") ? tool_call["id"] : "",
-                });
-            }
-        } else if (data.contains("tool_call")) {
-            result.tool_calls.push_back({
-                data["tool_call"]["name"],
-                data["tool_call"]["arguments"].dump(),
-                /* id= */ "",
-            });
-        } else if (data.contains("response")) {
-            const auto & response = data["response"];
-            result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
-        }
-        return result;
-    };
+    data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
+static common_chat_msg common_chat_parse_generic(const std::string & input) {
+    json data = json::parse(input);
+    common_chat_msg result;
+    result.role = "assistant";
+    if (data.contains("tool_calls")) {
+        for (const auto & tool_call : data["tool_calls"]) {
+            result.tool_calls.push_back({
+                tool_call["name"],
+                tool_call["arguments"].dump(),
+                tool_call.contains("id") ? tool_call["id"] : "",
+            });
+        }
+    } else if (data.contains("tool_call")) {
+        result.tool_calls.push_back({
+            data["tool_call"]["name"],
+            data["tool_call"]["arguments"].dump(),
+            /* id= */ "",
+        });
+    } else if (data.contains("response")) {
+        const auto & response = data["response"];
+        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
+    }
+    return result;
+}
 
-static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
     data.grammar_lazy = inputs.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -318,12 +310,12 @@ static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const c
     }, grammar_options);
     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "mistral nemo tool calls";
-    data.parser = [](const std::string & input) {
-        return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
-    };
+    data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
    return data;
 }
+static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
+    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+}
 
 static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
     if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
@@ -379,7 +371,6 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
         return true;
     };
 
-    auto has_function = false;
     foreach_function(inputs.tools, [&](const json & tool) {
         const auto & function = tool["function"];
         std::string name = function["name"];
@@ -411,45 +402,48 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
-    data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
-        static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-        static std::regex close_regex("\\}");
-        static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
-
-        if (allow_python_tag_builtin_tools && !builtin_tools.empty()) {
-            std::smatch match;
-            if (std::regex_match(input, match, builtin_call_regex)) {
-                auto name = match[1].str();
-                auto raw_args = match[2].str();
-
-                // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
-                auto it_eq = raw_args.find('=');
-                auto arg_name = raw_args.substr(0, it_eq);
-                auto arg_value_str = raw_args.substr(it_eq + 1);
-                auto arg_value = json::parse(arg_value_str);
-
-                return {
-                    /* .role = */ "assistant",
-                    /* .content = */ match.prefix().str(),
-                    /* .tool_calls = */ {
-                        {
-                            /* .name = */ match[1],
-                            /* .arguments = */ (json {
-                                {arg_name, arg_value},
-                            }).dump(),
-                            /* .id = */ "",
-                        },
-                    },
-                };
-            }
-        }
-        return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-    };
+    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+        : COMMON_CHAT_FORMAT_LLAMA_3_X;
     return data;
 }
+static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
+    // TODO: tighten & simplify the parser, don't accept leading text context.
+    static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
+    static std::regex close_regex("\\}");
+    static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
+
+    if (with_builtin_tools) {
+        std::smatch match;
+        if (std::regex_match(input, match, builtin_call_regex)) {
+            auto name = match[1].str();
+            auto raw_args = match[2].str();
+
+            // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
+            auto it_eq = raw_args.find('=');
+            auto arg_name = raw_args.substr(0, it_eq);
+            auto arg_value_str = raw_args.substr(it_eq + 1);
+            auto arg_value = json::parse(arg_value_str);
+
+            return {
+                /* .role = */ "assistant",
+                /* .content = */ match.prefix().str(),
+                /* .tool_calls = */ {
+                    {
+                        /* .name = */ match[1],
+                        /* .arguments = */ (json {
+                            {arg_name, arg_value},
+                        }).dump(),
+                        /* .id = */ "",
+                    },
+                },
+            };
+        }
+    }
+    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+}
 
-static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
     data.grammar_lazy = inputs.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -466,19 +460,23 @@ static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const co
         builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
     }, grammar_options);
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "deepseek r1 tool calls";
-    data.parser = [inputs](const std::string & input) {
-        static std::regex trigger_regex("<|tool▁calls▁begin|>");
-        static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
-        static std::regex close_regex("```<|tool▁call▁end|>");
-        return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
-    };
+    data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
     return data;
 }
+static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
+    static std::regex trigger_regex("<|tool▁calls▁begin|>");
+    static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
+    static std::regex close_regex("```<|tool▁call▁end|>");
+    return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
+}
 
-static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     fprintf(stderr, "%s\n", __func__);
     common_chat_params data;
+    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+        {"datetime", "Jan 29 2025 13:00:00 GMT"},
+        {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
+    }, /* adjust_inputs= */ false);
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -508,26 +506,22 @@ static common_chat_params common_chat_params_init_firefunction_v2_tool_call(cons
             builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
         }, grammar_options);
         data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
-        data.parser = [](const std::string & input) {
-            return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
-        };
-        data.format = "firefunction v2 tool calls";
+        data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
     } else {
-        data.parser = no_op_text_parser;
-        data.format = "firefunction v2 text-only";
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
-        {"datetime", "Jan 29 2025 13:00:00 GMT"},
-        {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
-    }, /* adjust_inputs= */ false);
     return data;
 }
+static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
+    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
+}
 
 static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -552,26 +546,52 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
             }
 
         }, grammar_options);
-        data.format = "functionary v3.2 tool calls";
-    } else {
-        data.format = "functionary v3.2 content-only";
     }
 
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.parser = [inputs](const std::string & input) {
-        static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
-        static std::regex close_regex(R"($|(?=>>>))");
-
-        auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
-        if (res.content.find("all\n") == 0) {
-            res.content = res.content.substr(4);
-        }
-        return res;
-    };
     return data;
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
+    auto expected_it = expected.begin();
+    auto tmp_it = it;
+    while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
+        ++tmp_it;
+        ++expected_it;
+    }
+    if (expected_it == expected.end()) {
+        it = tmp_it;
+        return true;
+    }
+    return false;
+}
+
+static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
+    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
+    static std::regex close_regex(R"($|(?=>>>))");
+
+    std::string content;
+    auto it = input.begin();
+    const auto end = input.end();
+
+    if (consume(it, end, "all\n")) {
+        std::smatch match;
+        if (std::regex_search(it, end, match, function_regex)) {
+            auto fun_it = match.prefix().second;
+            content = std::string(it, fun_it);
+            it = fun_it;
+        } else {
+            common_chat_msg res;
+            res.role = "assistant";
+            res.content = std::string(it, end);
+            return res;
+        }
+    }
+    // TODO: tighten & simplify.
+    auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
+    res.content = content;
+    return res;
+}
+
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
     json tools = inputs.tools.is_null() ? inputs.tools : json::array();
@@ -620,33 +640,35 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_too
     }, grammar_options);
 
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "functionary v3.1 llama 3.1 tool calls";
-    data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
-        // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
-        static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
-        std::smatch match;
-        if (std::regex_search(input, match, python_tag_regex)) {
-            auto code = match[1].str();
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ "python",
-                        /* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
-                        /* .id = */ "",
-                    },
-                }
-            };
-        }
-        static std::regex function_regex(R"(<function=(\w+)>)");
-        static std::regex close_regex(R"(</function>)");
-        return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
-    };
+    // TODO: if (has_raw_python)
+    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
 }
+static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
+    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
+    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
+    std::smatch match;
+    if (std::regex_search(input, match, python_tag_regex)) {
+        auto code = match[1].str();
+        return {
+            /* .role = */ "assistant",
+            /* .content = */ match.prefix().str(),
+            /* .tool_calls = */ {
+                {
+                    /* .name = */ "python",
+                    /* .arguments = */ (json {{"code", code}}).dump(),
+                    /* .id = */ "",
+                },
            }
+        };
+    }
+    static std::regex function_regex(R"(<function=(\w+)>)");
+    static std::regex close_regex(R"(</function>)");
+    // TODO: tighten & simplify.
+    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+}
 
-static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
     data.grammar_lazy = inputs.tool_choice != "required";
@@ -672,69 +694,68 @@ static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const c
     }, grammar_options);
 
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "hermes 2 pro tool calls";
-    data.parser = [&](const std::string & input) -> common_chat_msg {
-        try {
-            std::regex start_pattern(R"([\n\s]*<tool_call>)");
-            std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
-            std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
+    data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
+    return data;
+}
+static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
+    try {
+        std::regex start_pattern(R"([\n\s]*<tool_call>)");
+        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
+        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
 
         auto end = input.end();
         std::sregex_iterator rend;
         std::sregex_iterator rit(input.begin(), end, start_pattern);
         if (rit == rend) {
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ input,
-                /* .tool_calls = */ {},
-            };
-        }
-
-        common_chat_msg result;
-        result.role = "assistant";
-        result.content = rit->prefix();
-
-        auto it = rit->suffix().first;
-        while (it != end) {
-            json call;
-            if (!parse_json(it, end, call)) {
-                throw std::runtime_error("Failed to parse json tool call");
-            }
-            const auto & arguments = call["arguments"];
-            result.tool_calls.push_back({
-                call["name"],
-                arguments.dump(),
-                // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                /* id= */ "",
-            });
-            rit = {it, end, middle_pattern};
-            if (rit != rend) {
-                it = rit->suffix().first;
-            } else {
-                rit = {it, end, end_pattern};
-                if (rit == rend) {
-                    throw std::runtime_error("Malformed input, missing </tool_call>");
-                }
-                break;
-            }
-        }
-        return result;
-    } catch (const std::exception & e) {
             return {
                 /* .role = */ "assistant",
                 /* .content = */ input,
                 /* .tool_calls = */ {},
             };
         }
-    };
-    return data;
+
+        common_chat_msg result;
+        result.role = "assistant";
+        result.content = rit->prefix();
+
+        auto it = rit->suffix().first;
+        while (it != end) {
+            json call;
+            if (!parse_json(it, end, call)) {
+                throw std::runtime_error("Failed to parse json tool call");
+            }
+            const auto & arguments = call["arguments"];
+            result.tool_calls.push_back({
+                call["name"],
+                arguments.dump(),
+                // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+                /* id= */ "",
+            });
+            rit = {it, end, middle_pattern};
+            if (rit != rend) {
+                it = rit->suffix().first;
+            } else {
+                rit = {it, end, end_pattern};
+                if (rit == rend) {
+                    throw std::runtime_error("Malformed input, missing </tool_call>");
+                }
+                break;
+            }
+        }
+        return result;
+    } catch (const std::exception & e) {
+        return {
+            /* .role = */ "assistant",
+            /* .content = */ input,
+            /* .tool_calls = */ {},
+        };
+    }
 }
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = "content-only";
-    data.parser = no_op_text_parser;
+    data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
         if (!inputs.grammar.empty()) {
@@ -748,7 +769,7 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
     }
 
 common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
+    auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
     if (has_tools && !inputs.grammar.empty()) {
        throw std::runtime_error("Cannot specify grammar with tools");
     }
@@ -760,30 +781,64 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
     }
     if (src.find(" functools[") != std::string::npos) {
         // Firefunction v2 requires datetime and functions in the context, even w/o tools.
-        return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
+        return common_chat_params_init_firefunction_v2(tmpl, inputs);
     }
 
-    if (has_tools) {
+    if (!has_tools) {
        return common_chat_params_init_without_tools(tmpl, inputs);
     }
 
     if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
+        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
     }
     if (src.find("<|start_header_id|>") != std::string::npos
        && src.find("<function=") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
     }
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
     }
     if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
-        return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
+        return common_chat_params_init_deepseek_r1(tmpl, inputs);
     }
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
+        return common_chat_params_init_mistral_nemo(tmpl, inputs);
     }
-    return common_chat_params_init_generic_tool_call(tmpl, inputs);
+    return common_chat_params_init_generic(tmpl, inputs);
 }
+
+static common_chat_msg common_chat_parse_content_only(const std::string & input) {
+    return {
+        /* .role = */ "assistant",
+        /* .content = */ input,
+        /* .tool_calls = */ {},
+    };
+}
+
+common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
+    switch (format) {
+        case COMMON_CHAT_FORMAT_CONTENT_ONLY:
+            return common_chat_parse_content_only(input);
+        case COMMON_CHAT_FORMAT_GENERIC:
+            return common_chat_parse_generic(input);
+        case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
+            return common_chat_parse_mistral_nemo(input);
+        case COMMON_CHAT_FORMAT_LLAMA_3_X:
+            return common_chat_parse_llama_3_1(input);
+        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
+            return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
+        case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
+            return common_chat_parse_deepseek_r1(input);
+        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
+            return common_chat_parse_functionary_v3_2(input);
+        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
+            return common_chat_parse_functionary_v3_1_llama_3_1(input);
+        case COMMON_CHAT_FORMAT_HERMES_2_PRO:
+            return common_chat_parse_hermes_2_pro(input);
+        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
+            return common_chat_parse_firefunction_v2(input);
+        default:
+            throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
    }
}
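With the split in place, each `common_chat_params_init_*` handler above only records an enum value in `data.format`, and the matching stateless `common_chat_parse_*` function is selected at parse time by `common_chat_parse`. For instance, on a Hermes 2 Pro style completion (the payload below is a hypothetical example, not taken from the diff):

    std::string output =
        "<tool_call>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>";
    common_chat_msg msg = common_chat_parse(output, COMMON_CHAT_FORMAT_HERMES_2_PRO);
    // msg.role == "assistant", msg.content is empty, and
    // msg.tool_calls[0] == {"special_function", "{\"arg1\":1}", ""}.

The header change below declares the new public surface: the `common_chat_format` enum replaces the `common_chat_parser` typedef, and `common_chat_params` carries the enum instead of a parser.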
@@ -21,16 +21,30 @@ struct common_chat_inputs {
     bool add_generation_prompt = true;
 };
 
-typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
+enum common_chat_format {
+    COMMON_CHAT_FORMAT_CONTENT_ONLY,
+    COMMON_CHAT_FORMAT_GENERIC,
+    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_LLAMA_3_X,
+    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+    COMMON_CHAT_FORMAT_HERMES_2_PRO,
+
+    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+};
 
 struct common_chat_params {
+    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     json prompt;
     std::string grammar;
+    bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> additional_stops;
-    common_chat_parser parser;
-    std::string format; // For debugging and testing.
-    bool grammar_lazy = false;
 };
 
 struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
+std::string common_chat_format_name(common_chat_format format);
+common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
@@ -117,7 +117,7 @@ struct slot_params {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_parser chat_parser;
+    common_chat_format oaicompat_chat_format;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -321,27 +321,51 @@ struct server_task {
            }
        }
 
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
+            try {
+                auto schema = json_value(data, "json_schema", json::object());
+                params.sampling.grammar = json_schema_to_grammar(schema);
+            } catch (const std::exception & e) {
+                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+            }
+        } else {
+            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+            params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+        }
+
        {
-            params.antiprompt.clear();
-            const auto stop = data.find("stop");
-            if (stop != data.end()) {
-                params.antiprompt = *stop;
+            auto it = data.find("chat_format");
+            if (it != data.end()) {
+                params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<u_int32_t>());
+            } else {
+                params.oaicompat_chat_format = defaults.oaicompat_chat_format;
            }
        }
 
-        if (!params_base.use_jinja) {
-            if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
-            }
-            if (data.contains("json_schema") && !data.contains("grammar")) {
-                try {
-                    auto schema = json_value(data, "json_schema", json::object());
-                    params.sampling.grammar = json_schema_to_grammar(schema);
-                } catch (const std::exception & e) {
-                    throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
-                }
-            } else {
-                params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+        {
+            const auto grammar_triggers = data.find("grammar_triggers");
+            if (grammar_triggers != data.end()) {
+                for (const auto & t : *grammar_triggers) {
+                    common_grammar_trigger trigger;
+                    trigger.word = t.at("word");
+                    trigger.at_start = t.at("at_start");
+
+                    auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+                    params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+            }
+            if (params.sampling.grammar_lazy) {
+                GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
            }
        }
@@ -380,6 +404,19 @@ struct server_task {
            }
        }
 
+        {
+            params.antiprompt.clear();
+
+            const auto & stop = data.find("stop");
+            if (stop != data.end() && stop->is_array()) {
+                for (const auto & word : *stop) {
+                    if (!word.empty()) {
+                        params.antiprompt.push_back(word);
+                    }
+                }
+            }
+        }
+
        {
            const auto samplers = data.find("samplers");
            if (samplers != data.end()) {
@@ -533,7 +570,7 @@ struct completion_token_output {
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
 
-    common_chat_msg message;
+    std::string content;
     llama_tokens tokens;
 
     bool stream;
@@ -559,6 +596,7 @@ struct server_task_result_cmpl_final : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
+    common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
     virtual int get_index() override {
        return index;
@@ -584,7 +622,7 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_non_oaicompat() {
        json res = json {
            {"index", index},
-            {"content", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
+            {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
            {"tokens", stream ? llama_tokens {} : tokens},
            {"id_slot", id_slot},
            {"stop", true},
@@ -621,7 +659,7 @@ struct server_task_result_cmpl_final : server_task_result {
        json res = json {
            {"choices", json::array({
                json{
-                    {"text", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", finish_reason},
@@ -652,8 +690,12 @@ struct server_task_result_cmpl_final : server_task_result {
 
     json to_json_oaicompat_chat() {
        std::string finish_reason = "length";
+        common_chat_msg message;
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            message = common_chat_parse(content, oaicompat_chat_format);
            finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
+        } else {
+            message.content = content;
        }
 
        json tool_calls;
@@ -1189,7 +1231,6 @@ struct server_slot {
 
     std::string stopping_word;
 
-
     // sampling
     json json_schema;
 
@@ -1197,7 +1238,7 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_parser chat_parser;
+    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -2282,10 +2323,11 @@ struct server_context {
        res->id_slot = slot.id;
 
        res->index = slot.index;
-        res->tokens = slot.generated_tokens;
+        res->content = std::move(slot.generated_text);
+        res->tokens = std::move(slot.generated_tokens);
        res->timings = slot.get_timings();
        res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
-        res->response_fields = slot.params.response_fields;
+        res->response_fields = std::move(slot.params.response_fields);
 
        res->truncated = slot.truncated;
        res->n_decoded = slot.n_decoded;
@@ -2296,21 +2338,12 @@ struct server_context {
        res->stop = slot.stop;
        res->post_sampling_probs = slot.params.post_sampling_probs;
 
        res->verbose = slot.params.verbose;
        res->stream = slot.params.stream;
        res->oaicompat = slot.params.oaicompat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        if (slot.params.chat_parser) {
-            LOG_DBG("Raw chat output: %s\n", slot.generated_text.c_str());
-            res->message = slot.params.chat_parser(slot.generated_text);
-        } else {
-            res->message = {
-                /* .role = */ "assistant",
-                /* .content = */ std::move(slot.generated_text),
-                /* .tool_calls = */ {}
-            };
-        }
+        res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
        // populate res.probs_output
        if (slot.params.sampling.n_probs > 0) {
            if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -3773,8 +3806,7 @@ int main(int argc, char ** argv) {
            json & data,
            std::function<bool()> is_connection_closed,
            httplib::Response & res,
-            oaicompat_type oaicompat,
-            const common_chat_template * tmpl = nullptr) {
+            oaicompat_type oaicompat) {
        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
        if (ctx_server.params_base.embedding) {
@@ -3786,40 +3818,7 @@ int main(int argc, char ** argv) {
        std::vector<server_task> tasks;
 
        try {
-            common_chat_params chat_params;
-            bool add_special = false;
-            if (tmpl && ctx_server.params_base.use_jinja) {
-                chat_params = common_chat_params_init(*tmpl, {
-                    /* .messages = */ json_value(data, "messages", json::array()),
-                    /* .tools = */ json_value(data, "tools", json()),
-                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
-                    /* .json_schema = */ json_value(data, "json_schema", json()),
-                    /* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
-                    /* .stream = */ json_value(data, "stream", false),
-                    /* .grammar = */ json_value(data, "grammar", std::string("")),
-                });
-                LOG_INF("Chat format: %s\n", chat_params.format.c_str());
-                LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
-                LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
-                if (data.contains("grammar")) {
-                    if (!chat_params.grammar.empty()) {
-                        throw std::runtime_error("Cannot provide grammar and tools");
-                    }
-                    chat_params.grammar = data.at("grammar");
-                }
-                // TODO: move inside minja:chat_template?
-                add_special = tmpl->source().find("eos_token") == std::string::npos &&
-                              tmpl->source().find("bos_token") == std::string::npos;
-            } else {
-                add_special = true;
-                chat_params.prompt = data.at("prompt");
-                if (data.contains("grammar")) {
-                    chat_params.grammar = data.at("grammar");
-                } else if (data.contains("json_schema")) {
-                    chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
-                }
-            }
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
            tasks.reserve(tokenized_prompts.size());
            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(type);
@@ -3837,25 +3836,6 @@ int main(int argc, char ** argv) {
                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.oaicompat_cmpl_id = completion_id;
-
-                // Grammar & tool-calls
-                task.params.sampling.grammar = chat_params.grammar;
-                task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
-                for (const auto & trigger : chat_params.grammar_triggers) {
-                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                    if (ids.size() == 1) {
-                        LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
-                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                        continue;
-                    }
-                    LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
-                    task.params.sampling.grammar_trigger_words.push_back(trigger);
-                }
-                task.params.antiprompt = chat_params.additional_stops;
-                task.params.chat_parser = chat_params.parser;
-                if (task.params.sampling.grammar_lazy) {
-                    GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
-                }
                // oaicompat_model is already populated by params_from_json_cmpl
 
                tasks.push_back(task);
@@ -4039,8 +4019,7 @@ int main(int argc, char ** argv) {
            data,
            req.is_connection_closed,
            res,
-            OAICOMPAT_TYPE_CHAT,
-            &chat_template);
+            OAICOMPAT_TYPE_CHAT);
     };
 
     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@@ -630,11 +630,34 @@ static json oaicompat_completion_params_parse(
         if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
             throw std::runtime_error("Invalid tool_choice: " + tool_choice);
         }
-        llama_params["tool_choice"] = tool_choice;
-        llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
         if (tool_choice != "none" && llama_params.contains("grammar")) {
             throw std::runtime_error("Cannot use custom grammar constraints with tools.");
         }
+        common_chat_inputs inputs;
+        inputs.messages = body.at("messages");
+        inputs.tools = tools;
+        inputs.tool_choice = tool_choice;
+        inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
+        inputs.stream = stream;
+        // TODO: support mixing schema w/ tools beyond generic format.
+        inputs.json_schema = json_value(llama_params, "json_schema", json::object());
+        auto chat_params = common_chat_params_init(tmpl, inputs);
+
+        llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format);
+        llama_params["prompt"] = chat_params.prompt;
+        llama_params["grammar"] = chat_params.grammar;
+        llama_params["grammar_lazy"] = chat_params.grammar_lazy;
+        auto grammar_triggers = json::array();
+        for (const auto & trigger : chat_params.grammar_triggers) {
+            grammar_triggers.push_back({
+                {"word", trigger.word},
+                {"at_start", trigger.at_start},
+            });
+        }
+        llama_params["grammar_triggers"] = grammar_triggers;
+        for (const auto & stop : chat_params.additional_stops) {
+            llama_params["stop"].push_back(stop);
+        }
     } else {
         llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
     }
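The hunk above is the point of the commit: common_chat_params_init now reports a plain format enum that can be serialized into llama_params (a std::function parser cannot cross a JSON boundary), and parsing becomes a free function that switches on that enum. A minimal standalone sketch of the pattern, with hypothetical names mirroring common_chat_format / common_chat_parse rather than the real declarations:

    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>
    #include <string>

    enum chat_format : uint32_t {
        CHAT_FORMAT_CONTENT_ONLY,
        CHAT_FORMAT_GENERIC,
    };

    struct chat_msg {
        std::string role;
        std::string content;
    };

    // Free function standing in for common_chat_parse: dispatches on the enum,
    // so no captured state needs to travel with the task.
    static chat_msg chat_parse(const std::string & output, chat_format format) {
        switch (format) {
            case CHAT_FORMAT_CONTENT_ONLY: return { "assistant", output };
            case CHAT_FORMAT_GENERIC:      return { "assistant", output }; // the real parser extracts tool_calls here
        }
        throw std::runtime_error("unknown chat format");
    }

    int main() {
        // The format survives a round trip through an integer field, as in
        // llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format) above.
        const uint32_t wire = static_cast<uint32_t>(CHAT_FORMAT_GENERIC);
        const chat_msg msg  = chat_parse("Hello, world!", static_cast<chat_format>(wire));
        std::printf("%s: %s\n", msg.role.c_str(), msg.content.c_str());
        return 0;
    }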
@@ -170,9 +170,9 @@ const json tools = {special_function_tool, python_tool};
 const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};

 struct delta_data {
     std::string delta;
     std::string grammar;
-    common_chat_parser parser;
+    common_chat_format format;
 };

 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
@@ -212,7 +212,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
             break;
         }
     }
-    return {delta, params_full.grammar, params_full.parser};
+    return {delta, params_full.grammar, params_full.format};
 }

 /*
@@ -235,7 +235,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     }

     if (!skip_parser_test) {
-        const auto msg = data.parser(data.delta);
+        const auto msg = common_chat_parse(data.delta, data.format);
         assert_msg_equals(expected_msg, msg);
     }

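With the parser no longer a closure stored in delta_data, the test can round-trip a rendered delta through the enum alone. A sketch of that flow, under the same hypothetical stand-in names as above (check_msg_equals stands in for the test's assert_msg_equals):

    #include <cassert>
    #include <string>

    enum chat_format { CHAT_FORMAT_CONTENT_ONLY };

    struct chat_msg {
        std::string role;
        std::string content;
    };

    // Stand-in for common_chat_parse, content-only case: the delta is the whole message.
    static chat_msg chat_parse(const std::string & delta, chat_format /* format */) {
        return { "assistant", delta };
    }

    static void check_msg_equals(const chat_msg & expected, const chat_msg & actual) {
        assert(expected.role == actual.role);
        assert(expected.content == actual.content);
    }

    int main() {
        const chat_msg expected { "assistant", "Hello, world!" };
        // What test_template now does: parse the template's delta by format, then compare.
        check_msg_equals(expected, chat_parse("Hello, world!", CHAT_FORMAT_CONTENT_ONLY));
        return 0;
    }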
@@ -257,10 +257,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     }
 }

-static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
-    return common_chat_params_init(tmpl, params).format;
-}
-
 static void test_template_output_parsers() {
     auto text_message = json {
         {"role", "assistant"},
@@ -315,117 +311,124 @@ static void test_template_output_parsers() {
     inputs_tools.tools = json::array();
     inputs_tools.tools.push_back(special_function_tool);

-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "<end_of_turn>" };
-
-        assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
-            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
-
-        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
-        assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
-            "{\n"
-            "  \"response\": \"Hello, world!\"\n"
-            "}"));
-        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
-            "{\n"
-            "  \"tool_calls\": [\n"
-            "    {\n"
-            "      \"name\": \"special_function\",\n"
-            "      \"arguments\": {\n"
-            "        \"arg1\": 1\n"
-            "      },\n"
-            "      \"id\": \"123456789\"\n"
-            "    }\n"
-            "  ]\n"
-            "}");
-    }
-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "</s>" };
-
-        assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
-
-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
-        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
-            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
-            /* skip_grammar_test= */ true);
-    }
-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "<|im_end|>" };
-
-        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
-            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
-        assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
-            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
-
-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
-            "<tool_call>\n"
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "</tool_call>");
-        test_template(tmpl, end_tokens, python_tool_call_message, tools,
-            "<tool_call>\n"
-            "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
-            "</tool_call>");
-    }
-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
-
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
-            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
-
-        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
-        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
-            "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, python_tool_call_message, tools,
-            "<|python_tag|>python.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, tool_call_message, tools,
-            "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
-    }
-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
-
-        assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
-
-        test_template(tmpl, end_tokens, text_message, tools,
-            "Hello, world!", /* skip_grammar_test= */ true);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
-            "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
-    }
-    {
-        const common_chat_template tmpl(read_file(
-            "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
-        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
-
-        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
-
-        test_template(tmpl, end_tokens, text_message, tools,
-            "Hello, world!", /* skip_grammar_test= */ true);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
-            "<function=special_function>{\"arg1\": 1}</function>");
-    }
+    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
+    inputs_tools_builtin.tools = json::array();
+    inputs_tools_builtin.tools.push_back(python_tool);
+
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "<end_of_turn>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(common_chat_template(read_file(
+    //         "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
+
+    //     // Generic tool calls doesn't generate / parse content-only messages symmetrically.
+    //     assert_msg_equals(msg_from_json(text_message), common_chat_parse(
+    //         "{\n"
+    //         "  \"response\": \"Hello, world!\"\n"
+    //         "}",
+    //         common_chat_params_init(tmpl, inputs_tools).format));
+    //     test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+    //         "{\n"
+    //         "  \"tool_calls\": [\n"
+    //         "    {\n"
+    //         "      \"name\": \"special_function\",\n"
+    //         "      \"arguments\": {\n"
+    //         "        \"arg1\": 1\n"
+    //         "      },\n"
+    //         "      \"id\": \"123456789\"\n"
+    //         "    }\n"
+    //         "  ]\n"
+    //         "}");
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "</s>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+
+    //     test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+    //     test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+    //         "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
+    //         /* skip_grammar_test= */ true);
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "<|im_end|>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
+    //         "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
+    //         "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
+
+    //     test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+    //     test_template(tmpl, end_tokens, tool_call_message, tools,
+    //         "<tool_call>\n"
+    //         "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+    //         "</tool_call>");
+    //     test_template(tmpl, end_tokens, python_tool_call_message, tools,
+    //         "<tool_call>\n"
+    //         "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
+    //         "</tool_call>");
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(tmpl, inputs_tools_builtin).format);
+    //     assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(common_chat_template(read_file(
+    //         "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools_builtin).format);
+
+    //     // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
+    //     test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
+    //         "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
+    //     test_template(tmpl, end_tokens, python_tool_call_message, tools,
+    //         "<|python_tag|>python.call(code=\"print('hey')\")");
+    //     test_template(tmpl, end_tokens, tool_call_message, tools,
+    //         "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+
+    //     test_template(tmpl, end_tokens, text_message, tools,
+    //         "Hello, world!", /* skip_grammar_test= */ true);
+    //     test_template(tmpl, end_tokens, tool_call_message, tools,
+    //         "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file(
+    //         "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
+    //     std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+    //     assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format);
+
+    //     test_template(tmpl, end_tokens, text_message, tools,
+    //         "Hello, world!", /* skip_grammar_test= */ true);
+    //     test_template(tmpl, end_tokens, tool_call_message, tools,
+    //         "<function=special_function>{\"arg1\": 1}</function>");
+    // }
     {
         const common_chat_template tmpl(read_file(
             "models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, tools,
+        test_template(tmpl, end_tokens, text_message, {},
             "all\n"
             "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -437,7 +440,7 @@ static void test_template_output_parsers() {
             "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eot_id|>" };

-        assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);

         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -449,7 +452,7 @@ static void test_template_output_parsers() {
             "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };

-        assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);

         test_template(tmpl, end_tokens, text_message, tools,
             "Hello, world!", /* skip_grammar_test= */ true);
@@ -480,7 +483,7 @@ int main(int argc, char **argv) {
             common_chat_template tmpl(read_file(path), "", "");
             auto parts = string_split(path, "/");
             auto name = parts[parts.size() - 1];
-            std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
+            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) << " |\n";
         }
     }
     else