fix chat template bug

toyer 2024-06-26 05:58:13 +00:00
parent e18a5365c3
commit 0595f03dd1
3 changed files with 4 additions and 7 deletions


@@ -3045,7 +3045,7 @@ class ChatGLMModel(Model):
         self.gguf_writer.add_token_types(toktypes)
         special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.chat_template = "chatglm4"
         special_vocab.merges = merges
         # only add special tokens when they were not already loaded from config.json
         special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])


@@ -19801,10 +19801,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
-    } else if (tmpl == "chatglm3" ||
-            (tmpl.find("add_generation_prompt") != std::string::npos &&
-             tmpl.find("for message in messages") != std::string::npos &&
-             tmpl.find("loop.first") != std::string::npos)) {
+    } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) {
         // chatglm3-6b
         ss << "[gMASK]" << "sop";
         for (auto message : chat) {
@@ -19814,7 +19811,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == "ChatGLM4") {
+    } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]<sop>") != std::string::npos) {
         ss << "[gMASK]" << "<sop>";
         for (auto message : chat) {
             std::string role(message->role);

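Note: the renamed "chatglm4" keyword and the new "[gMASK]<sop>" substring check above are reachable through the public llama_chat_apply_template() API. The following is a minimal sketch (not part of this commit) of how the chatglm4 branch can be exercised, assuming the signature of llama_chat_apply_template() at this point in the tree (model pointer, template name, message array, add_ass flag, output buffer):

// Sketch only, not part of this commit: exercising the renamed "chatglm4"
// template via llama_chat_apply_template(). A raw Jinja template containing
// "[gMASK]<sop>" is now routed to the same formatter branch.
#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    std::vector<llama_chat_message> msgs = {
        { "system", "You are a helpful assistant" },
        { "user",   "Hello" },
    };
    std::vector<char> buf(4096);

    // Pass the template name directly (model may be null in that case).
    int32_t n = llama_chat_apply_template(nullptr, "chatglm4",
                                          msgs.data(), msgs.size(),
                                          /*add_ass=*/true,
                                          buf.data(), (int32_t) buf.size());
    if (n > 0 && n <= (int32_t) buf.size()) {
        // The formatted prompt is expected to start with "[gMASK]<sop>".
        printf("%.*s\n", (int) n, buf.data());
    }
    return 0;
}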

@@ -61,7 +61,7 @@ int main(void) {
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
         // ChatGLM4
-        "ChatGLM4",
+        "chatglm4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B