fix flake8 warnings

Author: wonjun Jang
Date:   2023-11-28 16:46:54 +09:00 (committed by GitHub)
Parent: 61edd1bc59
Commit: 1f5357cbcf


@@ -22,7 +22,7 @@ from collections import OrderedDict
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar, Optional
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
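Aside: the only change in this hunk is adding Optional to the typing import. Optional[X] is shorthand for "X or None"; a minimal sketch of the kind of annotation it enables (the helper name here is illustrative, not from convert.py):

    from typing import Optional

    def find_tokenizer(vocab_dir: Optional[str] = None) -> Optional[str]:
        # The argument may be omitted (defaults to None), and the
        # function may return None when nothing is found.
        return vocab_dir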
@@ -310,10 +310,8 @@ class VocabLoader:
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
-            vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
-        except:
+        except Exception:
             self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
-            vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
 
         self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
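The hunk above fixes flake8 E722: a bare `except:` also swallows KeyboardInterrupt and SystemExit, while `except Exception:` lets them propagate. A self-contained sketch of the same fast-then-slow tokenizer fallback (the wrapper function is hypothetical; the AutoTokenizer calls mirror the diff):

    from transformers import AutoTokenizer

    def load_tokenizer(fname_tokenizer: str):
        try:
            # Prefer the fast (Rust-backed) tokenizer when one is available.
            return AutoTokenizer.from_pretrained(fname_tokenizer, trust_remote_code=True)
        except Exception:
            # Not a bare `except:`, so Ctrl-C still interrupts; fall back
            # to the slow, pure-Python tokenizer.
            return AutoTokenizer.from_pretrained(fname_tokenizer, use_fast=False, trust_remote_code=True)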
@@ -423,15 +421,15 @@ class VocabLoader:
                 return "llama"
 
             path_candidates.append(path_candidate)
 
-        raise FileNotFoundError(
-            f"Could not find {find_candidates} in {path} or its parent; "
-            "if it's in another directory, pass the directory as --vocab-dir")
+        raise FileNotFoundError(f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; if it's in another directory, pass the directory as --vocab-dir")
 
     def __repr__(self) -> str:
         return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
 
 
 Vocab: TypeAlias = 'VocabLoader'
 
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
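Besides collapsing the message onto one line, the rewritten `raise` fixes flake8 F821 (undefined name): `find_candidates` and `path` did not exist at that point, and are replaced by the in-scope `path_candidates` and `self.fname_tokenizer`. A standalone sketch of the search-then-raise pattern, with illustrative candidate file names:

    from pathlib import Path

    def find_vocab_file(fname_tokenizer: Path) -> Path:
        path_candidates = []
        for name in ("tokenizer.model", "vocab.json"):  # illustrative candidates
            for directory in (fname_tokenizer, fname_tokenizer.parent):
                path_candidate = directory / name
                if path_candidate.exists():
                    return path_candidate
                path_candidates.append(path_candidate)
        # Every name interpolated below is in scope, so F821 stays quiet.
        raise FileNotFoundError(
            f"Could not find {path_candidates} in {fname_tokenizer} or its parent; "
            "if it's in another directory, pass the directory as --vocab-dir")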
@@ -806,6 +804,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             break
         yield result
 
+
 def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
     if params.n_vocab != vocab.vocab_size:
         if params.n_vocab == vocab.vocab_size:
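The single added line here is a blank line: flake8 E302 requires two blank lines before a top-level `def`. A minimal illustration:

    def first() -> int:
        return 1


    # Exactly two blank lines above this definition keeps E302 quiet.
    def second() -> int:
        return 2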
@@ -907,11 +906,10 @@ class OutputFile:
         self.gguf.close()
 
     @staticmethod
-    def write_vocab_only(
-        fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
-        pad_vocab: bool = False,
-    ) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab,
+                         svocab: gguf.SpecialVocab,
+                         endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+                         pad_vocab: bool = False) -> None:
         check_vocab_size(params, vocab, pad_vocab = pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
@@ -939,13 +937,11 @@ class OutputFile:
         return dt.quantize(arr)
 
     @staticmethod
-    def write_all(
-        fname_out : Path, ftype: GGMLFileType, params: Params,
-        model : LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
-        concurrency: int = DEFAULT_CONCURRENCY,
-        endianess : gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
-        pad_vocab : bool = False,
-    ) -> None:
+    def write_all(fname_out : Path, ftype: GGMLFileType, params: Params,
+                  model : LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+                  concurrency: int = DEFAULT_CONCURRENCY,
+                  endianess : gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+                  pad_vocab : bool = False) -> None:
         check_vocab_size(params, vocab, pad_vocab = pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
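Both `write_vocab_only` and `write_all` are reflowed the same way: the first parameter moves up onto the `def` line and the remaining parameters are aligned under the opening parenthesis, the continuation style flake8 checks as E128. A schematic example with placeholder types, not the real signature:

    def write_all(fname_out: str, ftype: int, params: dict,
                  model: dict, vocab: object,
                  concurrency: int = 8,
                  pad_vocab: bool = False) -> None:
        # Continuation lines aligned with the opening parenthesis (E128-clean).
        ...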
@@ -1207,7 +1203,7 @@ def main(args_in: list[str] | None = None) -> None:
                                n_vocab = vocab.vocab_size)
 
         outfile = args.outfile
         OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
-            endianess = endianess, pad_vocab = args.padvocab)
+                                    endianess = endianess, pad_vocab = args.padvocab)
         print(f"Wrote {outfile}")
         return
@@ -1234,7 +1230,7 @@ def main(args_in: list[str] | None = None) -> None:
 
     print(f"Writing {outfile}, format {ftype}")
 
     OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
-        concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
+                         concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
     print(f"Wrote {outfile}")