From 1749209406287b0f9af8f54d18c25eaeb515e329 Mon Sep 17 00:00:00 2001
From: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
Date: Thu, 23 May 2024 20:50:15 -0400
Subject: [PATCH] refactor: Simplify huggingface hub api implementation

---
 gguf-py/gguf/huggingface_hub.py | 86 ++++++++++++++++++---------------
 1 file changed, 48 insertions(+), 38 deletions(-)

diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py
index 16774c927..8e0327e67 100644
--- a/gguf-py/gguf/huggingface_hub.py
+++ b/gguf-py/gguf/huggingface_hub.py
@@ -7,7 +7,15 @@ from hashlib import sha256
 import requests
 from transformers import AutoTokenizer
 
-from .constants import HF_MODEL_MAP, LLaMaModelType, LLaMaVocabType
+from .constants import (
+    GPT_PRE_TOKENIZER_DEFAULT,
+    HF_TOKENIZER_BPE_FILES,
+    HF_TOKENIZER_SPM_FILES,
+    MODEL_FILE_TYPE_NAMES,
+    VOCAB_TYPE_NAMES,
+    ModelFileType,
+    VocabType,
+)
 
 
 class HFHubRequest:
@@ -81,16 +89,11 @@ class HFHubBase:
         self.logger = logger
 
         self._hub = HFHubRequest(auth_token, logger)
-        self._models = list(HF_MODEL_MAP)
 
     @property
     def hub(self) -> HFHubRequest:
         return self._hub
 
-    @property
-    def models(self) -> list[dict[str, object]]:
-        return self._models
-
     @property
     def model_path(self) -> pathlib.Path:
         return self._model_path
@@ -103,45 +106,52 @@ class HFHubBase:
 class HFVocabRequest(HFHubBase):
     def __init__(
         self,
-        model_path: None | str | pathlib.Path,
         auth_token: str,
+        model_path: None | str | pathlib.Path,
         logger: None | logging.Logger
     ):
         super().__init__(model_path, auth_token, logger)
 
     @property
-    def tokenizer_type(self) -> LLaMaVocabType:
-        return LLaMaVocabType
+    def tokenizer_type(self) -> VocabType:
+        return VocabType
 
-    def resolve_filenames(self, tokt: LLaMaVocabType) -> tuple[str]:
-        filenames = ["config.json", "tokenizer_config.json", "tokenizer.json"]
-        if tokt == self.tokenizer_type.SPM:
-            filenames.append("tokenizer.model")
-        return tuple(filenames)
+    def get_vocab_name(self, vocab_type: VocabType) -> str:
+        return VOCAB_TYPE_NAMES.get(vocab_type)
 
-    def resolve_tokenizer_model(
-        self,
-        filename: str,
-        filepath: pathlib.Path,
-        model: dict[str, object]
-    ) -> None:
-        try:  # NOTE: Do not use bare exceptions! They mask issues!
-            resolve_url = self.hub.resolve_url(model['repo'], filename)
-            response = self.hub.download_file(resolve_url)
-            self.hub.write_file(response.content, filepath)
-        except requests.exceptions.HTTPError as e:
-            self.logger.error(f"Failed to download tokenizer {model['repo']}: {e}")
+    def get_vocab_enum(self, vocab_name: str) -> VocabType:
+        return {
+            "SPM": VocabType.SPM,
+            "BPE": VocabType.BPE,
+            "WPM": VocabType.WPM,
+        }.get(vocab_name, VocabType.NON)
 
-    def download_models(self) -> None:
-        for model in self.models:
-            os.makedirs(f"{self.model_path}/{model['repo']}", exist_ok=True)
-            filenames = self.resolve_filenames(model['tokt'])
-            for filename in filenames:
-                filepath = pathlib.Path(f"{self.model_path}/{model['repo']}/{filename}")
-                if filepath.is_file():
-                    self.logger.info(f"skipped pre-existing tokenizer {model['repo']} in {filepath}")
-                    continue
-                self.resolve_tokenizer_model(filename, filepath, model)
+    def get_vocab_filenames(self, vocab_type: VocabType) -> tuple[str, ...]:
+        if vocab_type == self.tokenizer_type.SPM:
+            return HF_TOKENIZER_SPM_FILES
+        # NOTE: WPM and BPE are equivalent
+        return HF_TOKENIZER_BPE_FILES
+
+    def get_vocab_file(
+        self, model_repo: str, file_name: str, file_path: pathlib.Path,
+    ) -> None:
+        # NOTE: Do not use bare exceptions! They mask issues!
+        # Allow the exception to occur or handle it explicitly.
+        resolve_url = self.hub.resolve_url(model_repo, file_name)
+        response = self.hub.download_file(resolve_url)
+        self.hub.write_file(response.content, file_path)
+        self.logger.info(f"Downloaded tokenizer {file_name} from {model_repo}")
+
+    def get_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
+        vocab_list = self.get_vocab_filenames(vocab_type)
+        for vocab_file in vocab_list:
+            self.get_vocab_file(model_repo, vocab_file, self.model_path)
+
+    def extract_normalizer(self) -> dict[str, object]:
+        pass
+
+    def extract_pre_tokenizers(self) -> dict[str, object]:
+        pass
 
     def generate_checksums(self) -> None:
         checksums = []
@@ -191,5 +201,5 @@ class HFModelRequest(HFHubBase):
         super().__init__(model_path, auth_token, logger)
 
     @property
-    def model_type(self) -> LLaMaModelType:
-        return LLaMaModelType
+    def model_type(self) -> ModelFileType:
+        return ModelFileType
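
Since the refactor drops the internal try/except in favor of letting errors propagate, callers now handle HTTP failures themselves. The following is a minimal usage sketch, not part of the patch: it assumes the refactored module is importable as gguf.huggingface_hub with VocabType in gguf.constants, uses a placeholder token, and picks openai-community/gpt2 as an arbitrary public BPE repo.

# Hypothetical usage sketch -- not part of the patch. Assumes the module
# layout above (gguf.huggingface_hub, gguf.constants) and placeholder values.
import logging
import pathlib

import requests

from gguf.constants import VocabType
from gguf.huggingface_hub import HFVocabRequest

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-vocab")

vocab_request = HFVocabRequest(
    auth_token="<hf-token>",            # placeholder; needed for gated repos
    model_path=pathlib.Path("models"),  # assumed local download root
    logger=logger,
)

# SPM repos also require tokenizer.model; BPE and WPM share one file list.
try:
    vocab_request.get_all_vocab_files("openai-community/gpt2", VocabType.BPE)
except requests.exceptions.HTTPError as e:
    # The refactor lets HTTP errors propagate; handle them explicitly here.
    logger.error("tokenizer download failed: %s", e)

Note that the removed download_models() created per-repo directories with os.makedirs() and skipped pre-existing files, while get_all_vocab_files() hands self.model_path to each download as written, so the caller should ensure the download root exists before calling.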