From 78861a3eb2f8583115cba378caad95b34c274b9c Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 13 Jan 2025 19:58:15 +0000
Subject: [PATCH] Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

---
 common/common.cpp                    | 16 ++--------------
 examples/run/run.cpp                 |  4 ++--
 examples/simple-chat/simple-chat.cpp |  2 +-
 include/llama.h                      |  2 +-
 src/llama-arch.cpp                   |  6 ++++--
 src/llama-arch.h                     |  4 +++-
 src/llama-model.cpp                  |  6 ++++--
 7 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 9cd371326..275aa7385 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1822,17 +1822,6 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
 {
     auto vocab = llama_model_get_vocab(model);
@@ -1841,9 +1830,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
     if (chat_template_override.empty()) {
-        // TODO:
-        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+        default_template_src = llama_model_chat_template(model, /* name */ nullptr);
+        tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 0ad8bb15b..1c838aa77 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index e8eda9c22..46aeae2a9 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        const char * tmpl = llama_model_chat_template(model);
+        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
diff --git a/include/llama.h b/include/llama.h
index a184884c7..b5462157f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -503,7 +503,7 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d7d277e72..a7260f495 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,            "tokenizer.huggingface.json"      },
     { LLM_KV_TOKENIZER_RWKV,               "tokenizer.rwkv.world"            },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,      "tokenizer.chat_template"         },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,    "tokenizer.chat_template.%s"      },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,         "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,         "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,         "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 349844790..122fdcebe 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f90f5e746..dea03c6f2 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3912,8 +3912,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+                          : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
    if (it == model->gguf_kv.end()) {
        return nullptr;
    }
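Usage note: below is a minimal caller-side sketch of the two-argument lookup introduced by this patch. The helper pick_chat_template and its surrounding setup are illustrative only (not part of the patch); it assumes model points to an already loaded llama_model and relies on the header comment above, which states that llama_model_chat_template() returns nullptr when the requested template is not available.

    #include "llama.h"

    #include <string>

    // Prefer the named "tool_use" chat template (the key the removed helper
    // used to read as "tokenizer.chat_template.tool_use") when the model
    // metadata provides one; otherwise fall back to the default template.
    // Either lookup may return nullptr, so guard before constructing a string.
    static std::string pick_chat_template(const struct llama_model * model) {
        const char * tool_use = llama_model_chat_template(model, /* name */ "tool_use");
        if (tool_use != nullptr) {
            return tool_use;
        }
        const char * dflt = llama_model_chat_template(model, /* name */ nullptr);
        return dflt != nullptr ? dflt : "";
    }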