Pass template group to common_chat_apply_template
This commit is contained in:
parent
4e8beb0c53
commit
34370803fb
6 changed files with 26 additions and 18 deletions
|
@ -1831,12 +1831,15 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_chat_apply_template(
|
std::string common_chat_apply_template(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_templates & tmpl,
|
||||||
const std::vector<common_chat_msg> & msgs,
|
const std::vector<common_chat_msg> & msgs,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const common_params_tools & tools)
|
const common_params_tools & tools)
|
||||||
{
|
{
|
||||||
|
const auto & tmpl_selected =
|
||||||
|
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
|
||||||
|
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
common_chat_inputs inputs;
|
common_chat_inputs inputs;
|
||||||
|
|
||||||
|
@ -1844,9 +1847,11 @@ std::string common_chat_apply_template(
|
||||||
for (const auto & msg : msgs) {
|
for (const auto & msg : msgs) {
|
||||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tools.tools() != nullptr) {
|
if (tools.tools() != nullptr) {
|
||||||
inputs.tools = *tools.tools();
|
inputs.tools = *tools.tools();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto choice = tools.choice();
|
auto choice = tools.choice();
|
||||||
if (std::holds_alternative<std::string>(choice)) {
|
if (std::holds_alternative<std::string>(choice)) {
|
||||||
inputs.tool_choice = std::get<std::string>(choice);
|
inputs.tool_choice = std::get<std::string>(choice);
|
||||||
|
@ -1857,9 +1862,10 @@ std::string common_chat_apply_template(
|
||||||
inputs.tool_choice = *choice_ptr;
|
inputs.tool_choice = *choice_ptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs.messages = messages;
|
inputs.messages = messages;
|
||||||
inputs.add_generation_prompt = add_ass;
|
inputs.add_generation_prompt = add_ass;
|
||||||
return common_chat_params_init(tmpl, inputs).prompt;
|
return common_chat_params_init(tmpl_selected, inputs).prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
int alloc_size = 0;
|
int alloc_size = 0;
|
||||||
|
@ -1872,7 +1878,7 @@ std::string common_chat_apply_template(
|
||||||
std::vector<char> buf(alloc_size);
|
std::vector<char> buf(alloc_size);
|
||||||
|
|
||||||
// run the first time to get the total output length
|
// run the first time to get the total output length
|
||||||
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||||
|
|
||||||
// error: chat template is not supported
|
// error: chat template is not supported
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
|
@ -1884,7 +1890,7 @@ std::string common_chat_apply_template(
|
||||||
// if it turns out that our buffer is too small, we resize it
|
// if it turns out that our buffer is too small, we resize it
|
||||||
if ((size_t) res > buf.size()) {
|
if ((size_t) res > buf.size()) {
|
||||||
buf.resize(res);
|
buf.resize(res);
|
||||||
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string formatted_chat(buf.data(), res);
|
std::string formatted_chat(buf.data(), res);
|
||||||
|
@ -1892,7 +1898,7 @@ std::string common_chat_apply_template(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_chat_format_single(
|
std::string common_chat_format_single(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_templates & tmpl,
|
||||||
const std::vector<common_chat_msg> & past_msg,
|
const std::vector<common_chat_msg> & past_msg,
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
|
@ -1916,7 +1922,7 @@ std::string common_chat_format_single(
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
|
std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) {
|
||||||
std::vector<common_chat_msg> msgs = {
|
std::vector<common_chat_msg> msgs = {
|
||||||
{"system", "You are a helpful assistant", {}},
|
{"system", "You are a helpful assistant", {}},
|
||||||
{"user", "Hello", {}},
|
{"user", "Hello", {}},
|
||||||
|
|
|
@ -675,7 +675,7 @@ struct common_chat_templates {
|
||||||
// If the built-in template is not supported, we default to chatml
|
// If the built-in template is not supported, we default to chatml
|
||||||
// If the custom "tmpl" is not supported, we throw an error
|
// If the custom "tmpl" is not supported, we throw an error
|
||||||
std::string common_chat_apply_template(
|
std::string common_chat_apply_template(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_templates & tmpl,
|
||||||
const std::vector<common_chat_msg> & chat,
|
const std::vector<common_chat_msg> & chat,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
|
@ -683,7 +683,7 @@ std::string common_chat_apply_template(
|
||||||
|
|
||||||
// Format single message, while taking into account the position of that message in chat history
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
std::string common_chat_format_single(
|
std::string common_chat_format_single(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_templates & tmpl,
|
||||||
const std::vector<common_chat_msg> & past_msg,
|
const std::vector<common_chat_msg> & past_msg,
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
|
@ -692,7 +692,7 @@ std::string common_chat_format_single(
|
||||||
|
|
||||||
// Returns an example of formatted chat
|
// Returns an example of formatted chat
|
||||||
std::string common_chat_format_example(
|
std::string common_chat_format_example(
|
||||||
const common_chat_template & tmpl, bool use_jinja);
|
const common_chat_templates & tmpl, bool use_jinja);
|
||||||
|
|
||||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
|
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
|
||||||
|
|
||||||
|
|
|
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
|
||||||
// print chat template example in conversation mode
|
// print chat template example in conversation mode
|
||||||
if (params.conversation_mode) {
|
if (params.conversation_mode) {
|
||||||
if (params.enable_chat_template) {
|
if (params.enable_chat_template) {
|
||||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
|
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
|
||||||
} else {
|
} else {
|
||||||
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
||||||
}
|
}
|
||||||
|
@ -268,9 +268,9 @@ int main(int argc, char ** argv) {
|
||||||
const common_params_tools & tools = common_params_tools())
|
const common_params_tools & tools = common_params_tools())
|
||||||
{
|
{
|
||||||
common_chat_msg new_msg{role, content, {}};
|
common_chat_msg new_msg{role, content, {}};
|
||||||
auto formatted = common_chat_format_single(
|
|
||||||
*chat_templates.template_default, chat_msgs, new_msg, role == "user",
|
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
|
||||||
g_params->use_jinja, tools);
|
new_msg, role == "user", g_params->use_jinja, tools);
|
||||||
|
|
||||||
chat_msgs.push_back({role, content, {}});
|
chat_msgs.push_back({role, content, {}});
|
||||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||||
|
|
|
@ -4468,7 +4468,7 @@ int main(int argc, char ** argv) {
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||||
ctx_server.chat_templates.template_default->source().c_str(),
|
ctx_server.chat_templates.template_default->source().c_str(),
|
||||||
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
|
common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
||||||
ctx_server.process_single_task(task);
|
ctx_server.process_single_task(task);
|
||||||
|
|
|
@ -348,7 +348,7 @@ static llama_tokens format_infill(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||||
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
|
inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
|
||||||
std::vector<common_chat_msg> chat;
|
std::vector<common_chat_msg> chat;
|
||||||
|
|
||||||
for (size_t i = 0; i < messages.size(); ++i) {
|
for (size_t i = 0; i < messages.size(); ++i) {
|
||||||
|
@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["stop"].push_back(stop);
|
llama_params["stop"].push_back(stop);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle "n" field
|
// Handle "n" field
|
||||||
|
|
|
@ -339,7 +339,8 @@ int main(void) {
|
||||||
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
|
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
|
||||||
|
|
||||||
auto fmt_sys = [&](std::string tmpl_str) {
|
auto fmt_sys = [&](std::string tmpl_str) {
|
||||||
minja::chat_template tmpl(tmpl_str, "", "");
|
common_chat_templates tmpl;
|
||||||
|
tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
|
||||||
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
|
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
|
||||||
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||||
printf("-------------------------\n");
|
printf("-------------------------\n");
|
||||||
|
@ -366,7 +367,8 @@ int main(void) {
|
||||||
common_chat_msg new_msg{"user", "How are you", {}};
|
common_chat_msg new_msg{"user", "How are you", {}};
|
||||||
|
|
||||||
auto fmt_single = [&](std::string tmpl_str) {
|
auto fmt_single = [&](std::string tmpl_str) {
|
||||||
minja::chat_template tmpl(tmpl_str, "", "");
|
common_chat_templates tmpl;
|
||||||
|
tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
|
||||||
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
|
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
|
||||||
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||||
printf("-------------------------\n");
|
printf("-------------------------\n");
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue