BERT tokenizer fixes (#6498)
Key changes: * BERT conversion: fix abuse of LlamaHfVocab, do not set BOS or EOS * Nomic Embed conversion: pad vocab instead of slicing embedding tensor * llama_tokenize: handle added special tokens like HF does
This commit is contained in:
parent
c4a3a4ff47
commit
1b67731e18
20 changed files with 221 additions and 194 deletions
|
@ -227,15 +227,14 @@ class Model(ABC):
|
|||
return ("pytorch_model.bin",)
|
||||
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
||||
|
||||
def _set_vocab_gpt2(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_basic_vocab(self) -> tuple[list[str], list[int]]:
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
|
@ -255,11 +254,15 @@ class Model(ABC):
|
|||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
return tokens, toktypes
|
||||
|
||||
def _set_vocab_gpt2(self) -> None:
|
||||
tokens, toktypes = self.get_basic_vocab()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_qwen(self):
|
||||
|
@ -2043,34 +2046,25 @@ class BertModel(Model):
|
|||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def set_vocab(self):
|
||||
# use huggingface vocab to get all tokens
|
||||
vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
|
||||
tokens, scores, toktypes = zip(*vocab.all_tokens())
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
self.vocab_size = vocab.vocab_size
|
||||
tokens, toktypes = self.get_basic_vocab()
|
||||
self.vocab_size = len(tokens)
|
||||
|
||||
# we need this to validate the size of the token_type embeddings
|
||||
# though currently we are passing all zeros to the token_type embeddings
|
||||
n_token_types = len(set(toktypes))
|
||||
self.gguf_writer.add_token_type_count(n_token_types)
|
||||
self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B"
|
||||
|
||||
# convert to phantom space vocab
|
||||
def phantom(tok, typ):
|
||||
if tok.startswith(b"[") and tok.endswith(b"]"):
|
||||
def phantom(tok):
|
||||
if tok.startswith("[") and tok.endswith("]"):
|
||||
return tok
|
||||
if tok.startswith(b"##"):
|
||||
if tok.startswith("##"):
|
||||
return tok[2:]
|
||||
return b"\xe2\x96\x81" + tok
|
||||
tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
|
||||
|
||||
# set up bos and eos tokens (cls and sep)
|
||||
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
|
||||
return "\u2581" + tok
|
||||
tokens = list(map(phantom, tokens))
|
||||
|
||||
# add vocab to gguf
|
||||
self.gguf_writer.add_tokenizer_model("bert")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
# handle special tokens
|
||||
|
@ -2142,16 +2136,6 @@ class NomicBertModel(BertModel):
|
|||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||
|
||||
def get_tensors(self):
|
||||
assert self.vocab_size is not None
|
||||
for name, data in super().get_tensors():
|
||||
# Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
|
||||
if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
|
||||
rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
|
||||
assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
|
||||
data = data[:self.vocab_size, :]
|
||||
yield name, data
|
||||
|
||||
|
||||
@Model.register("GemmaForCausalLM")
|
||||
class GemmaModel(Model):
|
||||
|
@ -2327,7 +2311,8 @@ class MambaModel(Model):
|
|||
data = data.astype(np.float32)
|
||||
|
||||
# if f16 desired, convert big float32 2-dim weight tensors to float16
|
||||
if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
||||
new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
|
||||
if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
||||
data = data.astype(np.float16)
|
||||
|
||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue