Update bruteforce test: fix pyright complaints

jaime-m-p 2024-08-05 20:58:15 +02:00
parent 735105edf9
commit fd6d9b9e6a

@@ -124,8 +124,7 @@ class LibLlamaModel:
                 text = self.detokenize([id], remove_special=False, unparse_special=True)
             else:
                 text = self.lib.llama_token_get_text(self.model, id)
-                text = self.ffi.string(text)
-                text = str(text, encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+                text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
             vocab.append(text)
         return vocab
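
The first hunk merges the two-step decode into one expression and adds a cast: pyright cannot narrow ffi.string() to bytes on its own (cffi's string() can also return str for wide-character data), so the cast pins the type before str() decodes it. A minimal standalone sketch of the same pattern; the FFI object and the C string below are illustrative, not the test's own libllama setup:

from typing import cast

from cffi import FFI

ffi = FFI()
c_text = ffi.new("char[]", b"token \xff text")  # raw C string, not guaranteed valid UTF-8

# As far as the type checker knows, ffi.string() may return str instead of bytes,
# so cast to bytes before decoding, as the commit does.
text = str(cast(bytes, ffi.string(c_text)), encoding="utf-8", errors="replace")
print(text)  # invalid bytes come out as '\uFFFD'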
@@ -162,12 +161,13 @@ class TokenizerGroundtruth (Tokenizer):
         self.eos_token = self.model.eos_token
 
     def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
         max_token_id = max(self.model.get_vocab().values())
         if detokenize:
             ids = list(range(max_token_id + 1))
             vocab = self.model.batch_decode(ids, skip_special_tokens=False)
         else:
-            vocab = [None] * (max_token_id + 1)
+            vocab = [""] * (max_token_id + 1)
             for text, id in self.model.get_vocab().items():
                 vocab[id] = text
         return vocab
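
Both changes in this hunk address the same pyright error: with a declared return type of list[str], the placeholder list [None] * n infers as list[None] and fails to unify with the annotation. Seeding with empty strings (and pre-declaring vocab: list[str]) keeps every branch consistent. A self-contained sketch of the failure and the fix, using hypothetical helpers and a made-up token table:

def get_vocab_bad(token_to_id: dict[str, int]) -> list[str]:
    vocab = [None] * (max(token_to_id.values()) + 1)  # inferred as list[None]
    for text, id in token_to_id.items():
        vocab[id] = text                              # pyright: str not assignable to None
    return vocab                                      # pyright: list[None] is not list[str]


def get_vocab_fixed(token_to_id: dict[str, int]) -> list[str]:
    vocab: list[str] = [""] * (max(token_to_id.values()) + 1)  # placeholder is a real str
    for text, id in token_to_id.items():
        vocab[id] = text
    return vocab


print(get_vocab_fixed({"<unk>": 0, "hello": 2}))  # ['<unk>', '', 'hello']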
@@ -455,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
             return True
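
The last hunk simply deletes the nested find_first_mismatch helper, whose list[int] | str union let the zip() pair elements of mixed types. For reference, the same logic can be written pyright-clean as a generic over Sequence; this is a sketch, not code from the commit:

from collections.abc import Sequence
from typing import TypeVar

T = TypeVar("T")


def find_first_mismatch(seq1: Sequence[T], seq2: Sequence[T]) -> int:
    # Index of the first differing element; -1 if the sequences are equal;
    # the shorter length if one sequence is a strict prefix of the other.
    for i, (a, b) in enumerate(zip(seq1, seq2)):
        if a != b:
            return i
    if len(seq1) == len(seq2):
        return -1
    return min(len(seq1), len(seq2))


assert find_first_mismatch([1, 2, 3], [1, 2, 4]) == 2
assert find_first_mismatch("abc", "abc") == -1
assert find_first_mismatch("ab", "abc") == 2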