split chat handler vs. parser around enum again
parent 590c97931a
commit f8e14bffc3
5 changed files with 501 additions and 427 deletions
489 common/chat.cpp
@@ -4,6 +4,23 @@
#include "log.h"
#include "minja.hpp"

std::string common_chat_format_name(common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
        case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
        case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
        case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
        case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
        default:
            throw std::runtime_error("Unknown chat format");
    }
}

const common_grammar_options grammar_options {
    /* .dotall = */ false,
    /* .compact_spaces = */ false,
@@ -55,25 +72,21 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
    }
}

/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
static common_chat_msg parse_json_tool_calls(
    const std::string& input,
    const std::optional<std::regex> & trigger_opt,
    const std::regex & function_regex,
    const std::regex & close_regex) {
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";

    std::vector<std::string> tool_names;
    if (check_names) {
        for (const auto & tool : tools) {
            if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
                continue;
            }
            tool_names.push_back(tool["function"]["name"]);
        }
    }

    auto end = input.end();
    auto it = input.begin();
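(A minimal sketch, not part of the diff: per the comment above, function_regex must expose the tool name in capture group 1 and close_regex must terminate the JSON arguments that follow. With the new, tool-list-free signature a format-specific caller would look roughly like this; the regexes are hypothetical.)

    static common_chat_msg parse_hypothetical_format(const std::string & input) {
        // group 1 captures the function name; JSON arguments follow until close_regex matches
        static std::regex function_regex(R"(<call name="([^"]+)">)");
        static std::regex close_regex(R"(</call>)");
        return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
    }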
@@ -96,24 +109,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
|||
break;
|
||||
}
|
||||
auto name = rit->str(1);
|
||||
if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
|
||||
fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
|
||||
result.content += std::string(it, rit->suffix().first);
|
||||
it = rit->suffix().first;
|
||||
continue;
|
||||
}
|
||||
|
||||
result.content += std::string(it, rit->prefix().second);
|
||||
it = rit->suffix().first;
|
||||
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
|
||||
std::string src(it, end);
|
||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
|
@@ -162,15 +162,7 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
}
|
||||
}
|
||||
|
||||
static common_chat_msg no_op_text_parser(const std::string & input) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
|
@@ -252,35 +244,35 @@ static common_chat_params common_chat_params_init_generic_tool_call(const common
|
|||
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
||||
|
||||
data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "generic tool calls";
|
||||
data.parser = [&](const std::string & input) {
|
||||
json data = json::parse(input);
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (data.contains("tool_calls")) {
|
||||
for (const auto & tool_call : data["tool_calls"]) {
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
tool_call["arguments"].dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
} else if (data.contains("tool_call")) {
|
||||
result.tool_calls.push_back({
|
||||
data["tool_call"]["name"],
|
||||
data["tool_call"]["arguments"].dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
} else if (data.contains("response")) {
|
||||
const auto & response = data["response"];
|
||||
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
||||
json data = json::parse(input);
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (data.contains("tool_calls")) {
|
||||
for (const auto & tool_call : data["tool_calls"]) {
|
||||
result.tool_calls.push_back({
|
||||
tool_call["name"],
|
||||
tool_call["arguments"].dump(),
|
||||
tool_call.contains("id") ? tool_call["id"] : "",
|
||||
});
|
||||
}
|
||||
} else if (data.contains("tool_call")) {
|
||||
result.tool_calls.push_back({
|
||||
data["tool_call"]["name"],
|
||||
data["tool_call"]["arguments"].dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
} else if (data.contains("response")) {
|
||||
const auto & response = data["response"];
|
||||
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@@ -318,12 +310,12 @@ static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const c
|
|||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "mistral nemo tool calls";
|
||||
data.parser = [](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
};
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
}
|
||||
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
|
@@ -379,7 +371,6 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|||
return true;
|
||||
};
|
||||
|
||||
auto has_function = false;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
|
@@ -411,45 +402,48 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
|
||||
data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
|
||||
if (allow_python_tag_builtin_tools && !builtin_tools.empty()) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||
};
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
// TODO: tighten & simplify the parser, don't accept leading text context.
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
if (with_builtin_tools) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@@ -466,19 +460,23 @@ static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const co
|
|||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "deepseek r1 tool calls";
|
||||
data.parser = [inputs](const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(inputs.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
||||
};
|
||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
fprintf(stderr, "%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
}, /* adjust_inputs= */ false);
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@@ -508,26 +506,22 @@ static common_chat_params common_chat_params_init_firefunction_v2_tool_call(cons
|
|||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
data.parser = [](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
};
|
||||
data.format = "firefunction v2 tool calls";
|
||||
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
||||
} else {
|
||||
data.parser = no_op_text_parser;
|
||||
data.format = "firefunction v2 text-only";
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
}, /* adjust_inputs= */ false);
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
||||
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@@ -552,26 +546,52 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
}
|
||||
|
||||
}, grammar_options);
|
||||
data.format = "functionary v3.2 tool calls";
|
||||
} else {
|
||||
data.format = "functionary v3.2 content-only";
|
||||
}
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.parser = [inputs](const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
|
||||
if (res.content.find("all\n") == 0) {
|
||||
res.content = res.content.substr(4);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
std::string content;
|
||||
auto it = input.begin();
|
||||
const auto end = input.end();
|
||||
|
||||
if (consume(it, end, "all\n")) {
|
||||
std::smatch match;
|
||||
if (std::regex_search(it, end, match, function_regex)) {
|
||||
auto fun_it = match.prefix().second;
|
||||
content = std::string(it, fun_it);
|
||||
it = fun_it;
|
||||
} else {
|
||||
common_chat_msg res;
|
||||
res.role = "assistant";
|
||||
res.content = std::string(it, end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
// TODO: tighten & simplify.
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
||||
res.content = content;
|
||||
return res;
|
||||
}
|
||||
|
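(A minimal illustration, not part of the diff, of the ">>>" wire format described in the comment at the top of the Functionary v3.2 handler; the literal completion text and values are invented.)

    common_chat_msg msg = common_chat_parse_functionary_v3_2(
        "all\nSure, calling it.>>>special_function\n{\"arg1\": 1}");
    // expected roughly:
    //   msg.content              == "Sure, calling it."
    //   msg.tool_calls[0].name   == "special_function"
    //   msg.tool_calls[0].arguments == "{\"arg1\":1}"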
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
|
@@ -620,33 +640,35 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_too
|
|||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "functionary v3.1 llama 3.1 tool calls";
|
||||
data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
|
||||
};
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ (json {{"code", code}}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
// TODO: tighten & simplify.
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar_lazy = inputs.tool_choice != "required";
|
||||
|
@@ -672,69 +694,68 @@ static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const c
|
|||
}, grammar_options);
|
||||
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "hermes 2 pro tool calls";
|
||||
data.parser = [&](const std::string & input) -> common_chat_msg {
|
||||
try {
|
||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
|
||||
try {
|
||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
result.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
json call;
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
call["name"],
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
if (rit != rend) {
|
||||
it = rit->suffix().first;
|
||||
} else {
|
||||
rit = {it, end, end_pattern};
|
||||
if (rit == rend) {
|
||||
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} catch (const std::exception & e) {
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
};
|
||||
return data;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
result.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
json call;
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
call["name"],
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
if (rit != rend) {
|
||||
it = rit->suffix().first;
|
||||
} else {
|
||||
rit = {it, end, end_pattern};
|
||||
if (rit == rend) {
|
||||
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} catch (const std::exception & e) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = "content-only";
|
||||
data.parser = no_op_text_parser;
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
data.grammar_lazy = false;
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
if (!inputs.grammar.empty()) {
|
||||
|
@@ -748,7 +769,7 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
|||
}
|
||||
|
||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||
auto has_tools = inputs.tools.is_null() || inputs.tool_choice == "none";
|
||||
auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
|
||||
if (has_tools && !inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
|
@@ -760,30 +781,64 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
|||
}
|
||||
if (src.find(" functools[") != std::string::npos) {
|
||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||
return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (has_tools) {
|
||||
if (!has_tools) {
|
||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||
}
|
||||
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
||||
}
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||
}
|
||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||
return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
}
|
||||
return common_chat_params_init_generic_tool_call(tmpl, inputs);
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ input,
|
||||
/* .tool_calls = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
||||
switch (format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
return common_chat_parse_content_only(input);
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
return common_chat_parse_generic(input);
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
return common_chat_parse_mistral_nemo(input);
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
return common_chat_parse_llama_3_1(input);
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
return common_chat_parse_deepseek_r1(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
return common_chat_parse_functionary_v3_2(input);
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
return common_chat_parse_hermes_2_pro(input);
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
||||
}
|
||||
}
|
|
@@ -21,16 +21,30 @@ struct common_chat_inputs {
    bool add_generation_prompt = true;
};

typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
    COMMON_CHAT_FORMAT_LLAMA_3_X,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
    COMMON_CHAT_FORMAT_HERMES_2_PRO,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

struct common_chat_params {
    json prompt;
    std::string grammar;
    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    json prompt;
    std::string grammar;
    bool grammar_lazy = false;
    std::vector<common_grammar_trigger> grammar_triggers;
    std::vector<std::string> additional_stops;
    // std::unique_ptr<class common_chat_parser> parser;
    common_chat_parser parser;
    std::string format; // For debugging and testing.
    bool grammar_lazy = false;
    std::vector<std::string> additional_stops;
};

struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
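(A short usage sketch, not part of the diff, showing the intent of the split declared above: template-time handling now returns a format enum instead of a parser callback, and that enum is all that is needed later to parse the model output. messages_json, tools_json, tmpl and generated_text are placeholders.)

    common_chat_inputs inputs;
    inputs.messages = messages_json;   // placeholder request data
    inputs.tools    = tools_json;      // placeholder tool definitions

    // 1) At prompt-build time: pick prompt, grammar and the format enum.
    common_chat_params params = common_chat_params_init(tmpl, inputs);

    // 2) After generation: the plain enum selects the parser, no std::function needed.
    common_chat_msg msg = common_chat_parse(generated_text, params.format);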
@@ -117,7 +117,7 @@ struct slot_params {
|
|||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_parser chat_parser;
|
||||
common_chat_format oaicompat_chat_format;
|
||||
|
||||
json to_json() const {
|
||||
std::vector<std::string> samplers;
|
||||
|
@@ -321,27 +321,51 @@ struct server_task {
|
|||
}
|
||||
}
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
}
|
||||
|
||||
{
|
||||
params.antiprompt.clear();
|
||||
const auto stop = data.find("stop");
|
||||
if (stop != data.end()) {
|
||||
params.antiprompt = *stop;
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<u_int32_t>());
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
}
|
||||
}
|
||||
|
||||
if (!params_base.use_jinja) {
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
{
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
common_grammar_trigger trigger;
|
||||
trigger.word = t.at("word");
|
||||
trigger.at_start = t.at("at_start");
|
||||
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
}
|
||||
if (params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -380,6 +404,19 @@ struct server_task {
|
|||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.antiprompt.clear();
|
||||
|
||||
const auto & stop = data.find("stop");
|
||||
if (stop != data.end() && stop->is_array()) {
|
||||
for (const auto & word : *stop) {
|
||||
if (!word.empty()) {
|
||||
params.antiprompt.push_back(word);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto samplers = data.find("samplers");
|
||||
if (samplers != data.end()) {
|
||||
|
@@ -533,7 +570,7 @@ struct completion_token_output {
|
|||
struct server_task_result_cmpl_final : server_task_result {
|
||||
int index = 0;
|
||||
|
||||
common_chat_msg message;
|
||||
std::string content;
|
||||
llama_tokens tokens;
|
||||
|
||||
bool stream;
|
||||
|
@@ -559,6 +596,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
@@ -584,7 +622,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
json to_json_non_oaicompat() {
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? llama_tokens {} : tokens},
|
||||
{"id_slot", id_slot},
|
||||
{"stop", true},
|
||||
|
@@ -621,7 +659,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
json res = json {
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
|
@@ -652,8 +690,12 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg message;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
message = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
} else {
|
||||
message.content = content;
|
||||
}
|
||||
|
||||
json tool_calls;
|
||||
|
@@ -1189,7 +1231,6 @@ struct server_slot {
|
|||
|
||||
std::string stopping_word;
|
||||
|
||||
|
||||
// sampling
|
||||
json json_schema;
|
||||
|
||||
|
@@ -1197,7 +1238,7 @@ struct server_slot {
|
|||
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_parser chat_parser;
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
@@ -2282,10 +2323,11 @@ struct server_context {
|
|||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.index;
|
||||
res->tokens = slot.generated_tokens;
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->response_fields = slot.params.response_fields;
|
||||
res->response_fields = std::move(slot.params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
|
@@ -2296,21 +2338,12 @@ struct server_context {
|
|||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
if (slot.params.chat_parser) {
|
||||
LOG_DBG("Raw chat output: %s\n", slot.generated_text.c_str());
|
||||
res->message = slot.params.chat_parser(slot.generated_text);
|
||||
} else {
|
||||
res->message = {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ std::move(slot.generated_text),
|
||||
/* .tool_calls = */ {}
|
||||
};
|
||||
}
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
|
@@ -3773,8 +3806,7 @@ int main(int argc, char ** argv) {
|
|||
json & data,
|
||||
std::function<bool()> is_connection_closed,
|
||||
httplib::Response & res,
|
||||
oaicompat_type oaicompat,
|
||||
const common_chat_template * tmpl = nullptr) {
|
||||
oaicompat_type oaicompat) {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
if (ctx_server.params_base.embedding) {
|
||||
|
@@ -3786,40 +3818,7 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
common_chat_params chat_params;
|
||||
bool add_special = false;
|
||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||
chat_params = common_chat_params_init(*tmpl, {
|
||||
/* .messages = */ json_value(data, "messages", json::array()),
|
||||
/* .tools = */ json_value(data, "tools", json()),
|
||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||
/* .json_schema = */ json_value(data, "json_schema", json()),
|
||||
/* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
|
||||
/* .stream = */ json_value(data, "stream", false),
|
||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||
});
|
||||
LOG_INF("Chat format: %s\n", chat_params.format.c_str());
|
||||
LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
|
||||
LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
|
||||
if (data.contains("grammar")) {
|
||||
if (!chat_params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot provide grammar and tools");
|
||||
}
|
||||
chat_params.grammar = data.at("grammar");
|
||||
}
|
||||
// TODO: move inside minja:chat_template?
|
||||
add_special = tmpl->source().find("eos_token") == std::string::npos &&
|
||||
tmpl->source().find("bos_token") == std::string::npos;
|
||||
} else {
|
||||
add_special = true;
|
||||
chat_params.prompt = data.at("prompt");
|
||||
if (data.contains("grammar")) {
|
||||
chat_params.grammar = data.at("grammar");
|
||||
} else if (data.contains("json_schema")) {
|
||||
chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||
}
|
||||
}
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
@@ -3837,25 +3836,6 @@ int main(int argc, char ** argv) {
|
|||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
|
||||
// Grammar & tool-calls
|
||||
task.params.sampling.grammar = chat_params.grammar;
|
||||
task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
task.params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
task.params.antiprompt = chat_params.additional_stops;
|
||||
task.params.chat_parser = chat_params.parser;
|
||||
if (task.params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
|
||||
}
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(task);
|
||||
|
@@ -4039,8 +4019,7 @@ int main(int argc, char ** argv) {
|
|||
data,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_CHAT,
|
||||
&chat_template);
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
|
|
|
@@ -630,11 +630,34 @@ static json oaicompat_completion_params_parse(
|
|||
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
||||
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
||||
}
|
||||
llama_params["tool_choice"] = tool_choice;
|
||||
llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
|
||||
if (tool_choice != "none" && llama_params.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
common_chat_inputs inputs;
|
||||
inputs.messages = body.at("messages");
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.stream = stream;
|
||||
// TODO: support mixing schema w/ tools beyond generic format.
|
||||
inputs.json_schema = json_value(llama_params, "json_schema", json::object());
|
||||
auto chat_params = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
grammar_triggers.push_back({
|
||||
{"word", trigger.word},
|
||||
{"at_start", trigger.at_start},
|
||||
});
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
} else {
|
||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||
}
|
||||
|
|
|
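(Not part of the diff: as the hunk above and the matching server.cpp change show, the chosen chat format now crosses the HTTP-handler/task boundary as a plain integer field in the internal request JSON, roughly as follows.)

    // sender (utils.hpp):   llama_params["chat_format"] = static_cast<u_int32_t>(chat_params.format);
    // receiver (server.cpp): params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<u_int32_t>());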
@@ -170,9 +170,9 @@ const json tools = {special_function_tool, python_tool};
|
|||
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
||||
|
||||
struct delta_data {
|
||||
std::string delta;
|
||||
std::string grammar;
|
||||
common_chat_parser parser;
|
||||
std::string delta;
|
||||
std::string grammar;
|
||||
common_chat_format format;
|
||||
};
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
|
||||
|
@@ -212,7 +212,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|||
break;
|
||||
}
|
||||
}
|
||||
return {delta, params_full.grammar, params_full.parser};
|
||||
return {delta, params_full.grammar, params_full.format};
|
||||
}
|
||||
|
||||
/*
|
||||
|
@@ -235,7 +235,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
}
|
||||
|
||||
if (!skip_parser_test) {
|
||||
const auto msg = data.parser(data.delta);
|
||||
const auto msg = common_chat_parse(data.delta, data.format);
|
||||
assert_msg_equals(expected_msg, msg);
|
||||
}
|
||||
|
||||
|
@@ -257,10 +257,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
}
|
||||
}
|
||||
|
||||
static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
|
||||
return common_chat_params_init(tmpl, params).format;
|
||||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
auto text_message = json {
|
||||
{"role", "assistant"},
|
||||
|
@@ -315,117 +311,124 @@ static void test_template_output_parsers() {
|
|||
inputs_tools.tools = json::array();
|
||||
inputs_tools.tools.push_back(special_function_tool);
|
||||
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
common_chat_inputs inputs_tools_builtin = inputs_no_tools;
|
||||
inputs_tools_builtin.tools = json::array();
|
||||
inputs_tools_builtin.tools.push_back(python_tool);
|
||||
|
||||
assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
|
||||
"{\n"
|
||||
" \"response\": \"Hello, world!\"\n"
|
||||
"}"));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\",\n"
|
||||
" \"arguments\": {\n"
|
||||
" \"arg1\": 1\n"
|
||||
" },\n"
|
||||
" \"id\": \"123456789\"\n"
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "</s>" };
|
||||
// assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(common_chat_template(read_file(
|
||||
// "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// // Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||
/* skip_grammar_test= */ true);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
// assert_msg_equals(msg_from_json(text_message), common_chat_parse(
|
||||
// "{\n"
|
||||
// " \"response\": \"Hello, world!\"\n"
|
||||
// "}",
|
||||
// common_chat_params_init(tmpl, inputs_tools).format));
|
||||
// test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
// "{\n"
|
||||
// " \"tool_calls\": [\n"
|
||||
// " {\n"
|
||||
// " \"name\": \"special_function\",\n"
|
||||
// " \"arguments\": {\n"
|
||||
// " \"arg1\": 1\n"
|
||||
// " },\n"
|
||||
// " \"id\": \"123456789\"\n"
|
||||
// " }\n"
|
||||
// " ]\n"
|
||||
// "}");
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
// test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
// "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||
// /* skip_grammar_test= */ true);
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
|
||||
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
|
||||
// "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(common_chat_template(read_file(
|
||||
// "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
// test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
// "<tool_call>\n"
|
||||
// "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
// "</tool_call>");
|
||||
// test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
// "<tool_call>\n"
|
||||
// "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
// "</tool_call>");
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(tmpl, inputs_tools_builtin).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, common_chat_params_init(common_chat_template(read_file(
|
||||
// "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools_builtin).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
// // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
// "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
// test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
// "<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
// test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
// assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
// test_template(tmpl, end_tokens, text_message, tools,
|
||||
// "Hello, world!", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file(
|
||||
// "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
// std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
// assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools,
|
||||
// "Hello, world!", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
// "<function=special_function>{\"arg1\": 1}</function>");
|
||||
// }
|
||||
{
|
||||
const common_chat_template tmpl(read_file(
|
||||
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
test_template(tmpl, end_tokens, text_message, {},
|
||||
"all\n"
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
@@ -437,7 +440,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@@ -449,7 +452,7 @@ static void test_template_output_parsers() {
|
|||
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"Hello, world!", /* skip_grammar_test= */ true);
|
||||
|
@@ -480,7 +483,7 @@ int main(int argc, char **argv) {
|
|||
common_chat_template tmpl(read_file(path), "", "");
|
||||
auto parts = string_split(path, "/");
|
||||
auto name = parts[parts.size() - 1];
|
||||
std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
|
||||
std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) << " |\n";
|
||||
}
|
||||
}
|
||||
else
|