refactor: Add prototyped bridge interface for transformers and tokenizers
parent 47ef6157a0
commit c4470108ab
1 changed file with 35 additions and 34 deletions
@@ -9,21 +9,18 @@ from huggingface_hub import login, model_info
 from sentencepiece import SentencePieceProcessor

 from .constants import (
     GPT_PRE_TOKENIZER_DEFAULT,
-    HF_TOKENIZER_BPE_FILES,
-    HF_TOKENIZER_SPM_FILES,
+    MODEL_TOKENIZER_BPE_FILES,
+    MODEL_TOKENIZER_SPM_FILES,
     ModelFileExtension,
-    NormalizerType,
-    PreTokenizerType,
-    VocabType,
+    ModelNormalizerType,
+    ModelPreTokenizerType,
+    ModelTokenizerType,
 )


 class HFHubBase:
     def __init__(
-        self,
-        model_path: None | str | pathlib.Path,
-        logger: None | logging.Logger,
+        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
     ):
         # Set the model path
         if model_path is None:
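The import block swaps the HF_*-prefixed tokenizer file tuples for MODEL_*-prefixed ones and renames the NormalizerType/PreTokenizerType/VocabType enums to Model*-prefixed equivalents. The definitions live in .constants and are not part of this diff; a minimal sketch of what the renamed tokenizer enum might look like, inferred from its use below (SPM appears in the diff; the BPE and WPM members are assumed from the "WPM and BPE are equivalent" note):

    # Hypothetical sketch; the real definition in .constants is not shown in this diff.
    from enum import Enum

    class ModelTokenizerType(Enum):
        SPM = "SPM"  # SentencePiece models (tokenizer.model)
        BPE = "BPE"  # byte-pair-encoding tokenizers (tokenizer.json)
        WPM = "WPM"  # WordPiece, handled the same as BPE by list_vocab_files
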
@@ -116,43 +113,33 @@ class HFHubRequest(HFHubBase):

 class HFHubTokenizer(HFHubBase):
     def __init__(
-        self,
-        model_path: None | str | pathlib.Path,
-        logger: None | logging.Logger,
+        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
     ):
         super().__init__(model_path, logger)

     @staticmethod
-    def list_vocab_files(vocab_type: VocabType) -> tuple[str]:
-        if vocab_type == VocabType.SPM.value:
-            return HF_TOKENIZER_SPM_FILES
+    def list_vocab_files(vocab_type: ModelTokenizerType) -> tuple[str, ...]:
+        if vocab_type == ModelTokenizerType.SPM.value:
+            return MODEL_TOKENIZER_SPM_FILES
         # NOTE: WPM and BPE are equivalent
-        return HF_TOKENIZER_BPE_FILES
+        return MODEL_TOKENIZER_BPE_FILES

     @property
     def default_pre_tokenizer(self) -> tuple[str, ...]:
         return GPT_PRE_TOKENIZER_DEFAULT

-    def config(self, model_repo: str) -> dict[str, object]:
-        path = self.model_path / model_repo / "config.json"
-        return json.loads(path.read_text(encoding="utf-8"))
-
-    def tokenizer_model(self, model_repo: str) -> SentencePieceProcessor:
+    def model(self, model_repo: str) -> SentencePieceProcessor:
         path = self.model_path / model_repo / "tokenizer.model"
         processor = SentencePieceProcessor()
         processor.LoadFromFile(path.read_bytes())
         return processor

-    def tokenizer_config(self, model_repo: str) -> dict[str, object]:
+    def config(self, model_repo: str) -> dict[str, object]:
         path = self.model_path / model_repo / "tokenizer_config.json"
         return json.loads(path.read_text(encoding="utf-8"))

-    def tokenizer_json(self, model_repo: str) -> dict[str, object]:
+    def json(self, model_repo: str) -> dict[str, object]:
         path = self.model_path / model_repo / "tokenizer.json"
         return json.loads(path.read_text(encoding="utf-8"))

     def get_normalizer(self, model_repo: str) -> None | dict[str, object]:
-        normalizer = self.tokenizer_json(model_repo).get("normalizer", dict())
+        normalizer = self.json(model_repo).get("normalizer", dict())
         if normalizer:
             self.logger.info(f"JSON:Normalizer: {json.dumps(normalizer, indent=2)}")
         else:
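With the class name already providing the tokenizer context, the accessors drop their tokenizer_ prefix, and the old config accessor for config.json moves to HFHubModel (see the last hunk). A usage sketch under the new names; the local path and repo id are placeholders, not values from this diff:

    tok = HFHubTokenizer(model_path="models", logger=None)
    repo = "some-org/some-model"  # hypothetical repo id
    sp = tok.model(repo)    # SentencePieceProcessor; was tok.tokenizer_model(repo)
    cfg = tok.config(repo)  # tokenizer_config.json;  was tok.tokenizer_config(repo)
    data = tok.json(repo)   # tokenizer.json;         was tok.tokenizer_json(repo)
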
@@ -160,7 +147,7 @@ class HFHubTokenizer(HFHubBase):
         return normalizer

     def get_pre_tokenizer(self, model_repo: str) -> None | dict[str, object]:
-        pre_tokenizer = self.tokenizer_json(model_repo).get("pre_tokenizer")
+        pre_tokenizer = self.json(model_repo).get("pre_tokenizer")
         if pre_tokenizer:
             self.logger.info(
                 f"JSON:PreTokenizer: {json.dumps(pre_tokenizer, indent=2)}"
@@ -171,7 +158,7 @@ class HFHubTokenizer(HFHubBase):
         return pre_tokenizer

     def get_added_tokens(self, model_repo: str) -> None | list[dict[str, object]]:
-        added_tokens = self.tokenizer_json(model_repo).get("added_tokens", list())
+        added_tokens = self.json(model_repo).get("added_tokens", list())
         if added_tokens:
             self.logger.info(f"JSON:AddedTokens: {json.dumps(added_tokens, indent=2)}")
         else:
@@ -179,7 +166,7 @@ class HFHubTokenizer(HFHubBase):
         return added_tokens

     def get_pre_tokenizer_json_hash(self, model_repo: str) -> None | str:
-        tokenizer = self.tokenizer_json(model_repo)
+        tokenizer = self.json(model_repo)
         tokenizer_path = self.model_path / model_repo / "tokenizer.json"
         if tokenizer.get("pre_tokenizer"):
             sha256sum = sha256(str(tokenizer.get("pre_tokenizer")).encode()).hexdigest()
@@ -189,7 +176,7 @@ class HFHubTokenizer(HFHubBase):
         return sha256sum

     def get_tokenizer_json_hash(self, model_repo: str) -> str:
-        tokenizer = self.tokenizer_json(model_repo)
+        tokenizer = self.json(model_repo)
         tokenizer_path = self.model_path / model_repo / "tokenizer.json"
         sha256sum = sha256(str(tokenizer).encode()).hexdigest()
         self.logger.info(f"Hashed '{tokenizer_path}' as {sha256sum}")
@@ -197,7 +184,7 @@ class HFHubTokenizer(HFHubBase):

     def log_tokenizer_json_info(self, model_repo: str) -> None:
         self.logger.info(f"{model_repo}")
-        tokenizer = self.tokenizer_json(model_repo)
+        tokenizer = self.json(model_repo)
         for k, v in tokenizer.items():
             if k not in ["added_tokens", "model"]:
                 self.logger.info(f"{k}:{json.dumps(v, indent=2)}")
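The call sites above simply follow the tokenizer_json -> json rename. One detail worth noting: both hash helpers digest str() of the parsed JSON dict rather than the raw file bytes, so the result is insensitive to whitespace in tokenizer.json but depends on Python's dict formatting. A standalone equivalent of get_tokenizer_json_hash, with an illustrative path:

    from hashlib import sha256
    import json
    import pathlib

    path = pathlib.Path("models/some-org/some-model/tokenizer.json")  # placeholder path
    tokenizer = json.loads(path.read_text(encoding="utf-8"))
    sha256sum = sha256(str(tokenizer).encode()).hexdigest()
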
@@ -255,6 +242,18 @@ class HFHubModel(HFHubBase):
         os.makedirs(dir_path, exist_ok=True)
         self._request_single_file(model_repo, file_name, dir_path / file_name)

+    def config(self, model_repo: str) -> dict[str, object]:
+        path = self.model_path / model_repo / "config.json"
+        return json.loads(path.read_text(encoding="utf-8"))
+
+    def architecture(self, model_repo: str) -> str:
+        config = self.config(model_repo)
+        # NOTE: Allow IndexError to be raised because something unexpected happened.
+        # The general assumption is that there is only a single architecture, but
+        # merged models may have multiple architecture types, so this method call
+        # is not guaranteed to succeed.
+        return config.get("architectures", [])[0]
+
     def download_model_files(
         self, model_repo: str, file_extension: ModelFileExtension
     ) -> None:
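HFHubModel picks up the config accessor for config.json along with a convenience architecture lookup. An illustrative call, assuming config.json has already been downloaded under model_path (the repo id and result are placeholders):

    hub = HFHubModel(model_path="models", logger=None)
    arch = hub.architecture("some-org/some-model")  # e.g. "LlamaForCausalLM"
    # Raises IndexError when "architectures" is missing or empty, which the
    # NOTE above treats as a deliberate signal that something unexpected happened.
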
@@ -263,7 +262,9 @@
         )
         self._request_listed_files(model_repo, filtered_files)

-    def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
+    def download_all_vocab_files(
+        self, model_repo: str, vocab_type: ModelTokenizerType
+    ) -> None:
         vocab_files = self.tokenizer.list_vocab_files(vocab_type)
         self._request_listed_files(model_repo, vocab_files)
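The annotation change mirrors list_vocab_files above. Note the latent mismatch there: the body compares vocab_type == ModelTokenizerType.SPM.value, so a plain enum member only matches if the enum is value-comparable (e.g. a str- or int-backed enum); otherwise callers must pass the member's value. An illustrative call with a placeholder repo id:

    hub.download_all_vocab_files("some-org/some-model", ModelTokenizerType.SPM.value)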