From 68db76595e1e9c87b4263a930047711282986b71 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Jan 2025 15:47:13 +0200
Subject: [PATCH] llama : update llama_chat_apply_template

ggml-ci
---
 common/common.cpp                    | 17 ++++++++---------
 examples/run/run.cpp                 |  4 ++--
 examples/server/server.cpp           |  3 ++-
 examples/simple-chat/simple-chat.cpp |  8 +++++---
 include/llama.h                      |  5 +++--
 src/llama-model.cpp                  |  9 +++++++++
 src/llama.cpp                        | 16 +---------------
 tests/test-chat-template.cpp         |  3 +--
 8 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index c6e3da378..cf846721c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1648,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1659,16 +1659,16 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1676,18 +1676,17 @@
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 5d89efc2c..d53d8c07e 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 22b5583cb..f11922339 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1740,7 +1740,8 @@ struct server_context {
 
     bool validate_builtin_chat_template() const {
         llama_chat_message chat[] = {{"user", "test"}};
-        int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+        const char * tmpl = llama_model_chat_template(model);
+        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
         return chat_res > 0;
     }
 
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index 74657de56..3888b9d43 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -161,12 +161,14 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        const char * tmpl = llama_model_chat_template(model);
+
         // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
@@ -183,7 +185,7 @@ int main(int argc, char ** argv) {
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
diff --git a/include/llama.h b/include/llama.h
index 6418f565b..ee634540a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -489,6 +489,9 @@ extern "C" {
     // Returns the total size of all the tensors in the model in bytes
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
+    // Get the default chat template. Returns nullptr if not available
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
@@ -1009,9 +1012,7 @@ extern "C" {
     /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
     /// @param length The size of the allocated buffer
     /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
-    /// TODO: change to llama_vocab
     LLAMA_API int32_t llama_chat_apply_template(
-              const struct llama_model * model,
                             const char * tmpl,
        const struct llama_chat_message * chat,
                                   size_t n_msg,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index d4808083a..7dad2beeb 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3839,6 +3839,15 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
+const char * llama_model_chat_template(const struct llama_model * model) {
+    const auto & it = model->gguf_kv.find("tokenizer.chat_template");
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
+}
+
 uint64_t llama_model_n_params(const struct llama_model * model) {
     return model->n_elements();
 }
diff --git a/src/llama.cpp b/src/llama.cpp
index ff0491c4e..0ea250ad0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9834,27 +9834,13 @@ int32_t llama_decode(
 //
 
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                     size_t n_msg,
                                       bool add_ass,
                                       char * buf,
                                    int32_t length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml"; // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index f1f9aec4d..77d386954 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -157,7 +157,7 @@ int main(void) {
     }
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
 
     for (size_t i = 0; i < templates.size(); i++) {
@@ -165,7 +165,6 @@
         std::string expected = expected_output[i];
         formatted_chat.resize(1024);
         res = llama_chat_apply_template(
-            nullptr,
             custom_template.c_str(),
             conversation,
             message_count,
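For reference, a minimal sketch (not part of the patch) of how a caller might use the updated API: fetch the model's built-in template once with llama_model_chat_template(), pass it as the first argument of llama_chat_apply_template(), and re-run with a larger buffer when the returned length exceeds the current one. The helper name format_chat and the initial buffer size are illustrative assumptions; it presumes an already-loaded llama_model and messages whose role/content pointers remain valid.

#include <string>
#include <vector>

#include "llama.h"

// hypothetical helper: format a chat with the model's default template;
// if the model carries no "tokenizer.chat_template" key, tmpl is nullptr and
// llama_chat_apply_template() falls back to "chatml" (see src/llama.cpp above)
static std::string format_chat(const llama_model * model,
                               const std::vector<llama_chat_message> & msgs,
                               bool add_ass) {
    const char * tmpl = llama_model_chat_template(model);

    // first pass with a guessed size; the return value is the required length
    std::vector<char> buf(1024);
    int32_t n = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(), add_ass,
                                          buf.data(), (int32_t) buf.size());
    if (n > (int32_t) buf.size()) {
        buf.resize(n);
        n = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(), add_ass,
                                      buf.data(), (int32_t) buf.size());
    }

    // negative return means the template could not be applied
    return n < 0 ? std::string() : std::string(buf.data(), n);
}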