Add chat template support for llama-cli (#8068)

* add chat template support for llama-cli

* add help message

* server: simplify format_chat

* more consistent naming

* improve

* add llama_chat_format_example

* fix server

* code style

* code style

* Update examples/main/main.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Xuan Son Nguyen 2024-06-25 13:56:49 +02:00 committed by GitHub
parent 3791ad2193
commit 48e6b92cc3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 154 additions and 49 deletions

View file

@ -1444,7 +1444,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"only commonly used templates are accepted:\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
@ -2604,12 +2607,67 @@ bool llama_should_add_bos_token(const llama_model * model) {
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}
//
// Chat template utils
//
bool llama_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass) {
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass) {
auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
std::vector<llama_chat_msg> chat_new(past_msg);
chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return formatted;
}
std::string llama_chat_format_example(const struct llama_model * model,
const std::string & tmpl) {
std::vector<llama_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return llama_chat_apply_template(model, tmpl, msgs, true);
}
//
// KV cache utils
//