From 0595f03dd171cdad54aafe79199f223d7a2e3f0b Mon Sep 17 00:00:00 2001
From: toyer <2042519524@qq.com>
Date: Wed, 26 Jun 2024 05:58:13 +0000
Subject: [PATCH] fix chat template bug

---
 convert-hf-to-gguf.py        | 2 +-
 llama.cpp                    | 7 ++-----
 tests/test-chat-template.cpp | 2 +-
 3 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 363c09720..f500b3492 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -3045,7 +3045,7 @@ class ChatGLMModel(Model):
         self.gguf_writer.add_token_types(toktypes)
 
         special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.chat_template = "chatglm4"
         special_vocab.merges = merges
         # only add special tokens when they were not already loaded from config.json
         special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
diff --git a/llama.cpp b/llama.cpp
index 76b5c3d91..4abdfa37a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -19801,10 +19801,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
-    } else if (tmpl == "chatglm3" ||
-            (tmpl.find("add_generation_prompt") != std::string::npos &&
-             tmpl.find("for message in messages") != std::string::npos &&
-             tmpl.find("loop.first") != std::string::npos)) {
+    } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) {
         // chatglm3-6b
         ss << "[gMASK]" << "sop";
         for (auto message : chat) {
@@ -19814,7 +19811,7 @@
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "ChatGLM4") {
+    } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]<sop>") != std::string::npos) {
         ss << "[gMASK]" << "<sop>";
         for (auto message : chat) {
             std::string role(message->role);
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index bcdfeee0e..e843be749 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -61,7 +61,7 @@ int main(void) {
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
         // ChatGLM4
-        "ChatGLM4",
+        "chatglm4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B