llama : update llama_chat_apply_template

ggml-ci

Georgi Gerganov 2025-01-09 15:47:13 +02:00
parent 22b31cd16d
commit 68db76595e
8 changed files with 31 additions and 34 deletions
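
For reference, the change drops the llama_model pointer from llama_chat_apply_template and adds llama_model_chat_template for fetching a model's built-in template. A minimal before/after sketch of a call site (the names model, chat, n_msg, buf and len are placeholders, not taken from this diff):

    // before this commit: the template was resolved inside the library via the model pointer
    //   int32_t res = llama_chat_apply_template(model, /*tmpl=*/nullptr, chat, n_msg, true, buf, len);

    // after this commit: the caller fetches the built-in template explicitly
    const char * tmpl = llama_model_chat_template(model); // nullptr if the GGUF has no template
    int32_t res = llama_chat_apply_template(tmpl, chat, n_msg, true, buf, len); // nullptr tmpl -> "chatml"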

@@ -1648,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1659,16 +1659,16 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1676,18 +1676,17 @@ std::string common_chat_apply_template(const struct llama_model * model,
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }

@@ -1740,7 +1740,8 @@ struct server_context {
 
     bool validate_builtin_chat_template() const {
         llama_chat_message chat[] = {{"user", "test"}};
-        int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+        const char * tmpl = llama_model_chat_template(model);
+        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
         return chat_res > 0;
     }
 
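
Both common_chat_verify_template and validate_builtin_chat_template rely on the same probe: calling llama_chat_apply_template with a null output buffer and zero length only computes the required output size, returning a negative value if the template is not supported. A small stand-alone sketch of that check (hypothetical helper, not part of the diff):

    // returns true if the given template string is handled by llama.cpp
    static bool template_is_supported(const char * tmpl) {
        llama_chat_message chat[] = {{"user", "test"}};
        // null buffer + zero length: only the required size (or an error) is computed
        const int32_t res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
        return res >= 0;
    }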

@@ -161,12 +161,14 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        const char * tmpl = llama_model_chat_template(model);
+
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
@@ -183,7 +185,7 @@ int main(int argc, char ** argv) {
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;

@@ -489,6 +489,9 @@ extern "C" {
     // Returns the total size of all the tensors in the model in bytes
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
+    // Get the default chat template. Returns nullptr if not available
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
@@ -1009,9 +1012,7 @@ extern "C" {
     /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
     /// @param length The size of the allocated buffer
     /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
-    /// TODO: change to llama_vocab
     LLAMA_API int32_t llama_chat_apply_template(
-                    const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                     size_t   n_msg,
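
As the header comment above describes, the function returns the total length of the formatted prompt, so callers typically apply it twice: once to learn the required size and once to fill the buffer. A minimal sketch of the new calling convention, mirroring the simple-chat example above (model and messages are assumed to exist; messages is a std::vector<llama_chat_message>):

    const char * tmpl = llama_model_chat_template(model); // nullptr -> "chatml" fallback

    std::vector<char> buf(4096); // initial guess
    int32_t len = llama_chat_apply_template(tmpl, messages.data(), messages.size(),
                                            /*add_ass=*/true, buf.data(), buf.size());
    if (len > (int32_t) buf.size()) {
        // buffer too small: grow to the reported size and re-apply
        buf.resize(len);
        len = llama_chat_apply_template(tmpl, messages.data(), messages.size(),
                                        /*add_ass=*/true, buf.data(), buf.size());
    }
    if (len < 0) {
        // negative return value: the template is not supported
    }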

@@ -3839,6 +3839,15 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
+const char * llama_model_chat_template(const struct llama_model * model) {
+    const auto & it = model->gguf_kv.find("tokenizer.chat_template");
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
+}
+
 uint64_t llama_model_n_params(const struct llama_model * model) {
     return model->n_elements();
 }
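
The new getter simply exposes the GGUF tokenizer.chat_template metadata. A hypothetical way to inspect it from application code (sketch, not part of the diff):

    const char * tmpl = llama_model_chat_template(model);
    if (tmpl == nullptr) {
        // the model carries no tokenizer.chat_template key;
        // passing nullptr to llama_chat_apply_template falls back to "chatml"
        fprintf(stderr, "model has no built-in chat template\n");
    } else {
        fprintf(stderr, "built-in chat template:\n%s\n", tmpl);
    }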

@@ -9834,27 +9834,13 @@ int32_t llama_decode(
 //
 
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                     size_t   n_msg,
                                       bool   add_ass,
                                      char * buf,
                                    int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml"; // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
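
Note the behavior change in this hunk: a nullptr tmpl used to mean "read tokenizer.chat_template from the model (asserting the model pointer is non-null)", while it now simply selects "chatml" and the library never touches the model here. Callers that want roughly the old behavior combine the two new functions:

    // prefer the model's built-in template; llama_model_chat_template returns nullptr
    // when the GGUF has none, and a nullptr tmpl now falls back to "chatml"
    int32_t res = llama_chat_apply_template(llama_model_chat_template(model),
                                            chat, n_msg, add_ass, buf, length);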

@@ -157,7 +157,7 @@ int main(void) {
     }
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
 
     for (size_t i = 0; i < templates.size(); i++) {
@@ -165,7 +165,6 @@ int main(void) {
         std::string expected = expected_output[i];
         formatted_chat.resize(1024);
         res = llama_chat_apply_template(
-            nullptr,
             custom_template.c_str(),
             conversation,
             message_count,