Update bruteforce test: fix pyright complaints

parent 735105edf9
commit fd6d9b9e6a

1 changed file with 3 additions and 11 deletions
@@ -124,8 +124,7 @@ class LibLlamaModel:
                 text = self.detokenize([id], remove_special=False, unparse_special=True)
             else:
                 text = self.lib.llama_token_get_text(self.model, id)
-                text = self.ffi.string(text)
-                text = str(text, encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+                text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
             vocab.append(text)
         return vocab
 
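Note on the first hunk: the pyright complaint comes from cffi's FFI.string(), which is typed as returning bytes | str (bytes for a char pointer, str for a wchar_t pointer), so the old str(text, encoding=...) call failed to type-check on the union. Folding the two lines through typing.cast(bytes, ...) narrows the type for the checker without changing runtime behavior. A minimal standalone sketch of the pattern, assuming cffi is installed (buf and raw are illustrative names, not from the test file):

from typing import cast

from cffi import FFI

ffi = FFI()
buf = ffi.new("char[]", b"hi \xff there")  # char*, so FFI.string() yields bytes at runtime

raw = ffi.string(buf)                      # pyright's view: bytes | str
text = str(cast(bytes, raw), encoding="utf-8", errors="replace")
print(text)                                # the invalid byte prints as '\uFFFD'

cast() is purely a static-typing assertion with no runtime conversion, which is why this edit is behavior-preserving.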
@@ -162,12 +161,13 @@ class TokenizerGroundtruth (Tokenizer):
         self.eos_token = self.model.eos_token
 
     def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
         max_token_id = max(self.model.get_vocab().values())
         if detokenize:
             ids = list(range(max_token_id + 1))
             vocab = self.model.batch_decode(ids, skip_special_tokens=False)
         else:
-            vocab = [None] * (max_token_id + 1)
+            vocab = [""] * (max_token_id + 1)
             for text, id in self.model.get_vocab().items():
                 vocab[id] = text
         return vocab
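Note on the second hunk: both changes serve the declared -> list[str] return type. With vocab = [None] * (max_token_id + 1), pyright infers list[None] and then rejects both the vocab[id] = text assignment and the final return; seeding with empty strings keeps the list at list[str] throughout, and the new vocab: list[str] = [] declaration pins the type for the batch_decode branch as well. A standalone sketch of the now type-clean shape (build_vocab and token_to_id are hypothetical names):

def build_vocab(token_to_id: dict[str, int]) -> list[str]:
    # seed with "" rather than None so the list is list[str] from the start
    vocab: list[str] = [""] * (max(token_to_id.values()) + 1)
    for text, id in token_to_id.items():
        vocab[id] = text
    return vocab

print(build_vocab({"<s>": 0, "world": 2, "hello": 1}))
# ['<s>', 'hello', 'world']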
@@ -455,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
             return True
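Note on the third hunk: it only deletes the nested find_first_mismatch helper from compare_tokenizers; the diff does not show where (or whether) the logic was reintroduced. For reference, the deleted helper as a standalone, annotated function: it returns the index of the first differing position, -1 when the two sequences are equal, and the shorter length when one is a strict prefix of the other.

def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str) -> int:
    for i, (a, b) in enumerate(zip(ids1, ids2)):
        if a != b:
            return i                      # first position where the sequences differ
    if len(ids1) == len(ids2):
        return -1                         # fully equal
    return min(len(ids1), len(ids2))      # one sequence is a prefix of the other

assert find_first_mismatch([1, 2, 3], [1, 9, 3]) == 1
assert find_first_mismatch("abc", "abc") == -1
assert find_first_mismatch([1, 2], [1, 2, 3]) == 2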