Update bruteforce test: fix pyright complaints
parent 735105edf9
commit fd6d9b9e6a
1 changed file with 3 additions and 11 deletions
@@ -124,8 +124,7 @@ class LibLlamaModel:
                 text = self.detokenize([id], remove_special=False, unparse_special=True)
             else:
                 text = self.lib.llama_token_get_text(self.model, id)
-                text = self.ffi.string(text)
-                text = str(text, encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
+                text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
             vocab.append(text)
         return vocab
 
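In this first hunk the two-step conversion is collapsed into a single decode, presumably because cffi's ffi.string() has no precise static return type and pyright could not verify the later str(...) call. A minimal standalone sketch of the same narrowing pattern, with hypothetical names (ffi and c_ptr are placeholders, not taken from the test file):

    from typing import cast

    def c_text_to_str(ffi, c_ptr) -> str:
        # ffi.string() returns bytes for a char*, but its declared type is loose;
        # cast() only narrows the static type and is a no-op at runtime.
        data = cast(bytes, ffi.string(c_ptr))
        # Decode, mapping invalid UTF-8 sequences to '\uFFFD'.
        return str(data, encoding="utf-8", errors="replace")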
@@ -162,12 +161,13 @@ class TokenizerGroundtruth (Tokenizer):
         self.eos_token = self.model.eos_token
 
     def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
         max_token_id = max(self.model.get_vocab().values())
         if detokenize:
             ids = list(range(max_token_id + 1))
             vocab = self.model.batch_decode(ids, skip_special_tokens=False)
         else:
-            vocab = [None] * (max_token_id + 1)
+            vocab = [""] * (max_token_id + 1)
             for text, id in self.model.get_vocab().items():
                 vocab[id] = text
         return vocab
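The second hunk concerns the declared return type list[str] of get_vocab(): pyright types [None] * n as list[None], which is not assignable to list[str], and the new explicit annotation pins down the element type up front. A self-contained sketch of the same shape, using a hypothetical token_to_id mapping in place of the HF tokenizer's get_vocab():

    def build_vocab(token_to_id: dict[str, int]) -> list[str]:
        # Explicit annotation so the checker knows the element type from the start.
        vocab: list[str] = []
        max_token_id = max(token_to_id.values())
        # Pre-fill with empty strings; [None] * n would be typed list[None]
        # and could not be returned as list[str].
        vocab = [""] * (max_token_id + 1)
        for text, id in token_to_id.items():
            vocab[id] = text
        return vocab

    print(build_vocab({"a": 0, "c": 2}))  # ['a', '', 'c']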
@@ -455,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2: # equal to TokenizerGroundtruth?
             return True
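The last hunk only removes the nested find_first_mismatch helper from compare_tokenizers; nothing is added in its place in this hunk. Purely for reference, the removed logic written as a standalone, explicitly typed function (illustrative only, not code introduced by this commit):

    from collections.abc import Sequence

    def find_first_mismatch(ids1: Sequence[int] | str, ids2: Sequence[int] | str) -> int:
        # Index of the first position where the two sequences differ.
        for i, (a, b) in enumerate(zip(ids1, ids2)):
            if a != b:
                return i
        # No mismatch in the shared prefix: -1 means fully identical,
        # otherwise the shorter sequence ends first.
        if len(ids1) == len(ids2):
            return -1
        return min(len(ids1), len(ids2))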