support glm-4-9b-chat

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
This commit is contained in:
XingXing Qiao 2024-06-17 10:08:52 +08:00
parent f3bc337f43
commit 1fc5bf5bcb
5 changed files with 116 additions and 7 deletions

View file

@ -476,6 +476,9 @@ class Model:
if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
# ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
res = "smaug-bpe"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if res is None:
logger.warning("\n")
@ -2714,7 +2717,7 @@ class DeepseekV2Model(Model):
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM
def set_vocab(self):
def set_vocab_chatglm3(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
@ -2725,7 +2728,8 @@ class ChatGLMModel(Model):
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
assert max(tokenizer.get_vocab().values()) < vocab_size
print(vocab_size)
print(max(tokenizer.get_vocab().values()))
for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
if token_id == 0:
@ -2774,6 +2778,91 @@ class ChatGLMModel(Model):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
@staticmethod
def token_bytes_to_string(b):
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
byte_encoder = bytes_to_unicode()
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
@staticmethod
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
return parts
def set_vocab(self):
if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
self.set_vocab_chatglm3()
return
dir_model = self.dir_model
hparams = self.hparams
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["padded_vocab_size"]
assert max(tokenizer.get_vocab().values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) >= 2 and len(merged) <= 7
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.chat_template = "ChatGLM4"
special_vocab.merges = merges
# only add special tokens when they were not already loaded from config.json
if len(special_vocab.special_token_ids) == 0:
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
# this one is usually not in config.json anyway
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
@ -2934,7 +3023,8 @@ def main() -> None:
with torch.inference_mode():
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
print(model_class)
print(model_instance)
logger.info("Set model parameters")
model_instance.set_gguf_parameters()