Update convert.py
This commit is contained in:
parent
9f4dc236a9
commit
dcf372e60e
1 changed files with 4 additions and 5 deletions
|
@ -345,8 +345,8 @@ class VocabLoader:
|
||||||
text = reverse_vocab[i].encode("utf-8")
|
text = reverse_vocab[i].encode("utf-8")
|
||||||
yield text, self.get_token_score(i), self.get_token_type(i)
|
yield text, self.get_token_score(i), self.get_token_type(i)
|
||||||
|
|
||||||
def get_token_type(self, token_id, default_type=gguf.TokenType.NORMAL):
|
def get_token_type(self, token_id):
|
||||||
toktype = default_type
|
toktype = gguf.TokenType.NORMAL
|
||||||
|
|
||||||
if self.spm is not None and token_id < self.spm.vocab_size():
|
if self.spm is not None and token_id < self.spm.vocab_size():
|
||||||
if self.spm.is_unknown(token_id):
|
if self.spm.is_unknown(token_id):
|
||||||
|
@ -372,17 +372,16 @@ class VocabLoader:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
default_toktype = gguf.TokenType.USER_DEFINED
|
|
||||||
|
|
||||||
for text in self.added_tokens_list:
|
for text in self.added_tokens_list:
|
||||||
|
|
||||||
if text in self.specials:
|
if text in self.specials:
|
||||||
|
|
||||||
toktype = self.get_token_type(self.specials[text], default_toktype)
|
toktype = self.get_token_type(self.specials[text])
|
||||||
score = self.get_token_score(self.specials[text])
|
score = self.get_token_score(self.specials[text])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
toktype = default_toktype
|
toktype = gguf.TokenType.USER_DEFINED
|
||||||
score = -1000.0
|
score = -1000.0
|
||||||
|
|
||||||
yield text.encode("utf-8"), score, toktype
|
yield text.encode("utf-8"), score, toktype
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue