Refactor common_chat_* functions to accept minja template + use_jinja option

This commit is contained in:
ochafik 2025-01-18 00:43:38 +00:00
parent 3ed670b6dd
commit b75d0622e4
7 changed files with 82 additions and 80 deletions

View file

@ -74,6 +74,15 @@
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
const char * LLAMA_CHATML_TEMPLATE = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
//
// CURL utils
//
@ -1748,56 +1757,56 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
return res >= 0;
}
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_apply_template(
const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass) {
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
}
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_format_single(
const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass) {
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@ -1805,29 +1814,20 @@ std::string common_chat_format_single(const struct llama_model * model,
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
const auto add_generation_prompt = true;
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
} else {
return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
}
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
@ -1847,14 +1847,7 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
if (!tool_use_template_src.empty()) {
default_template_src = tool_use_template_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
default_template_src = LLAMA_CHATML_TEMPLATE;
}
}
return {