support glm-4-9b-chat

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
Author: XingXing Qiao
Date:   2024-06-19 15:16:23 +08:00
Parent: 35c6887f78
Commit: de3c909db0

6 changed files with 151 additions and 2 deletions


@@ -483,6 +483,9 @@ class Model:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
 
         if res is None:
             logger.warning("\n")
@@ -2729,7 +2732,7 @@ class DeepseekV2Model(Model):
 class ChatGLMModel(Model):
     model_arch = gguf.MODEL_ARCH.CHATGLM
 
-    def set_vocab(self):
+    def set_vocab_chatglm3(self):
         dir_model = self.dir_model
         hparams = self.hparams
         tokens: list[bytearray] = []
@@ -2789,6 +2792,95 @@ class ChatGLMModel(Model):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
+
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|observation|>"])
+        special_vocab._set_special_token("eot", 151336)
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_name(self.dir_model.name)
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

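Aside (not part of the commit): a minimal sketch of the idea behind the bpe() helper added above. Given a tiktoken-style mergeable_ranks mapping, replaying all merges whose rank is lower than a token's own rank recovers the two pieces that merge directly into that token, which is how the converter rebuilds the GGUF merges list. The toy_ranks dictionary below is invented purely for illustration.

# Illustration only -- same algorithm as ChatGLMModel.bpe() above, on a toy rank table.
def bpe(mergeable_ranks, token, max_rank=None):
    parts = [bytes([b]) for b in token]
    while True:
        # find the adjacent pair with the lowest merge rank
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        # stop once only merges at or above the token's own rank remain
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

# toy_ranks is a made-up stand-in for tokenizer.mergeable_ranks
toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
print(bpe(toy_ranks, b"abc", max_rank=4))  # [b'ab', b'c'] -> merge entry "ab c"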

@@ -717,6 +717,8 @@ return html`
             <option value="vicuna">Tess</option>
             <option value="yi34b">Yi-6/9/34B-Chat</option>
             <option value="zephyr">Zephyr</option>
+            <option value="chatglm3">ChatGLM3-6B</option>
+            <option value="chatglm4">ChatGLM4-9B</option>
             <option value=""></option>
           </optgroup>
         </select>


@@ -327,5 +327,41 @@ export const promptFormats = {
     userMsgSuffix: "",
     stops: ""
+  },
+
+  // ----------------------------
+  "chatglm3": {
+    template: `[gMASK]sop<|system|>\n {{prompt}}{{history}}<|{{char}}|>`,
+    historyTemplate: `<|{{name}}|>\n {{message}}`,
+    char: "assistant",
+    charMsgPrefix: "",
+    charMsgSuffix: "",
+    user: "user",
+    userMsgPrefix: "",
+    userMsgSuffix: "",
+    stops: ""
+  },
+
+  // ----------------------------
+  "chatglm4": {
+    template: `[gMASK]<sop><|system|>\n{{prompt}}{{history}}<|{{char}}|>`,
+    historyTemplate: `<|{{name}}|>\n{{message}}`,
+    char: "assistant",
+    charMsgPrefix: "",
+    charMsgSuffix: "",
+    user: "user",
+    userMsgPrefix: "",
+    userMsgSuffix: "",
+    stops: ""
   }
 };


@@ -4730,6 +4730,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 28: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4922,6 +4923,9 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "poro-chat") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -13369,6 +13373,7 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_DBRX:
             case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
                 regex_exprs = {
                     // same as llama3
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -18914,6 +18919,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
+    } else if (tmpl.find("ChatGLM4") != std::string::npos) {
+        ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else {
         // template not supported
         return -1;


@@ -87,6 +87,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_DBRX      = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG     = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO      = 15,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4  = 16,
     };
 
     // note: these values should be synchronized with ggml_rope


@@ -59,6 +59,8 @@ int main(void) {
         "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // ChatGLM4
+        "ChatGLM4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -97,6 +99,8 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         // ChatGLM3
         "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+        // ChatGLM4
+        "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
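
Aside (not part of the commit): the ChatGLM4 branch added to llama_chat_apply_template_internal concatenates "[gMASK]<sop>", then "<|role|>\n" plus the content of each message, and appends "<|assistant|>" when a generation prompt is requested. A small Python mock-up of that string construction follows; the chatglm4_format helper name is hypothetical and the example only covers the first three turns of the test conversation.

# Illustration only -- mirrors the C++ ChatGLM4 formatting branch shown above.
def chatglm4_format(chat, add_ass=True):
    out = "[gMASK]" + "<sop>"
    for msg in chat:
        out += "<|" + msg["role"] + "|>" + "\n" + msg["content"]
    if add_ass:
        out += "<|assistant|>"  # generation prompt for the next assistant turn
    return out

chat = [
    {"role": "system",    "content": "You are a helpful assistant"},
    {"role": "user",      "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
# Prints the leading portion of the ChatGLM4 expected_output string above,
# with real newlines in place of the \n escapes.
print(chatglm4_format(chat, add_ass=False))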