convert : fix handling of added tokens
This commit is contained in:
parent
846c51feac
commit
8c8b5b0666
4 changed files with 24 additions and 36 deletions
|
@ -137,21 +137,18 @@ reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
|||
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
|
||||
added_token_ids = tokenizer.get_added_vocab().values()
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i in reverse_vocab:
|
||||
try:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
except KeyError:
|
||||
text = bytearray()
|
||||
for c in reverse_vocab[i]:
|
||||
if ord(c) < 256: # single byte character
|
||||
text.append(byte_decoder[ord(c)])
|
||||
else: # multibyte special token character
|
||||
text.extend(c.encode('utf-8'))
|
||||
else:
|
||||
if i not in reverse_vocab:
|
||||
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
|
||||
pad_token = f"[PAD{i}]".encode("utf8")
|
||||
text = bytearray(pad_token)
|
||||
else if i in added_tokens:
|
||||
# these tokens are not encoded, see https://github.com/huggingface/transformers/issues/1133
|
||||
text = bytearray(reverse_vocab[i].encode('utf-8'))
|
||||
else:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
|
|
|
@ -133,21 +133,18 @@ reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
|||
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
|
||||
added_token_ids = tokenizer.get_added_vocab().values()
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i in reverse_vocab:
|
||||
try:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
except KeyError:
|
||||
text = bytearray()
|
||||
for c in reverse_vocab[i]:
|
||||
if ord(c) < 256: # single byte character
|
||||
text.append(byte_decoder[ord(c)])
|
||||
else: # multibyte special token character
|
||||
text.extend(c.encode('utf-8'))
|
||||
else:
|
||||
if i not in reverse_vocab:
|
||||
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
|
||||
pad_token = f"[PAD{i}]".encode("utf8")
|
||||
text = bytearray(pad_token)
|
||||
else if i in added_tokens:
|
||||
# these tokens are not encoded, see https://github.com/huggingface/transformers/issues/1133
|
||||
text = bytearray(reverse_vocab[i].encode('utf-8'))
|
||||
else:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
|
|
|
@ -121,21 +121,18 @@ reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
|||
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
|
||||
added_token_ids = tokenizer.get_added_vocab().values()
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i in reverse_vocab:
|
||||
try:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
except KeyError:
|
||||
text = bytearray()
|
||||
for c in reverse_vocab[i]:
|
||||
if ord(c) < 256: # single byte character
|
||||
text.append(byte_decoder[ord(c)])
|
||||
else: # multibyte special token character
|
||||
text.extend(c.encode('utf-8'))
|
||||
else:
|
||||
if i not in reverse_vocab:
|
||||
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
|
||||
pad_token = f"[PAD{i}]".encode("utf8")
|
||||
text = bytearray(pad_token)
|
||||
else if i in added_tokens:
|
||||
# these tokens are not encoded, see https://github.com/huggingface/transformers/issues/1133
|
||||
text = bytearray(reverse_vocab[i].encode('utf-8'))
|
||||
else:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
|
|
|
@ -338,9 +338,6 @@ class BpeVocab:
|
|||
|
||||
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
tokenizer = self.bpe_tokenizer
|
||||
from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import]
|
||||
byte_encoder = tokenization_gpt2.bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
score = 0.0
|
||||
for i, item in enumerate(tokenizer):
|
||||
text: bytes = item.encode("utf-8")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue