From 99275a1606c2bdcaca268f76e79e97b3f12de59f Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sat, 25 May 2024 02:10:52 -0400 Subject: [PATCH] refactor: Simplify API and merge HFModel into HFHub --- gguf-py/gguf/huggingface_hub.py | 93 ++++++++++++++++----------------- 1 file changed, 46 insertions(+), 47 deletions(-) diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py index 4be75ab51..519eb94ed 100644 --- a/gguf-py/gguf/huggingface_hub.py +++ b/gguf-py/gguf/huggingface_hub.py @@ -98,14 +98,16 @@ class HFHub: # NOTE: Request repository metadata to extract remote filenames return [x.rfilename for x in model_info(model_repo).siblings] - def list_remote_model_files(self, model_repo: str, model_type: ModelFileType) -> list[str]: + def filter_remote_files( + self, model_repo: str, file_type: ModelFileType + ) -> list[str]: model_files = [] self.logger.info(f"Repo:{model_repo}") - self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[model_type]}:{model_type}") - for filename in HFModel.list_remote_files(model_repo): + self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[file_type]}:{file_type}") + for filename in HFHub.list_remote_files(model_repo): suffix = pathlib.Path(filename).suffix self.logger.debug(f"Suffix: {suffix}") - if suffix == MODEL_FILE_TYPE_NAMES[model_type]: + if suffix == MODEL_FILE_TYPE_NAMES[file_type]: self.logger.info(f"File: {filename}") model_files.append(filename) return model_files @@ -122,6 +124,42 @@ class HFHub: def model_path(self, value: pathlib.Path): self._model_path = value + def get_model_file( + self, model_repo: str, file_name: str, file_path: pathlib.Path + ) -> bool: + # NOTE: Do not use bare exceptions! They mask issues! + # Allow the exception to occur or explicitly handle it. 
+ try: + self.logger.info(f"Downloading '{file_name}' from {model_repo}") + resolve_url = self.hub.resolve_url(model_repo, file_name) + response = self.hub.download_file(resolve_url) + if not response.is_success(): + raise ValueError("Failed to download model file") + self.hub.write_file(response.content, file_path) + self.logger.info(f"Model file successfully saved to {file_path}") + return True + except requests.exceptions.HTTPError as e: + self.logger.error(f"Error while downloading '{file_name}': {str(e)}") + return False + + def _download_files(self, model_repo: str, remote_files: list[str]) -> None: + for file_name in remote_files: + dir_path = self.model_path / model_repo + os.makedirs(dir_path, exist_ok=True) + self.get_model_file(model_repo, file_name, dir_path / file_name) + + def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None: + filtered_files = self.filter_remote_files(model_repo, file_type) + self._download_files(model_repo, filtered_files) + + def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None: + vocab_files = self.tokenizer.get_vocab_filenames(vocab_type) + self._download_files(model_repo, vocab_files) + + def download_all_model_files(self, model_repo: str) -> None: + all_files = self.list_remote_files(model_repo) + self._download_files(model_repo, all_files) + class HFTokenizer(HFHub): def __init__(self, model_path: str, auth_token: str, logger: logging.Logger): @@ -146,6 +184,10 @@ class HFTokenizer(HFHub): "WPM": VocabType.WPM, }.get(vocab_name, VocabType.NON) + @property + def default_pre_tokenizer(self) -> str: + return GPT_PRE_TOKENIZER_DEFAULT + def config(self, model_repo: str) -> dict[str, object]: path = self.model_path / model_repo / "config.json" return json.loads(path.read_text(encoding='utf-8')) @@ -205,46 +247,3 @@ class HFTokenizer(HFHub): for x, y in v.items(): if x not in ["vocab", "merges"]: self.logger.info(f"{k}:{x}:{json.dumps(y, indent=2)}") - - -class 
HFModel(HFHub): - def __init__( - self, - auth_token: str, - model_path: None | str | pathlib.Path, - logger: None | logging.Logger - ): - super().__init__(model_path, auth_token, logger) - self._tokenizer = HFTokenizer(model_path, auth_token, logger) - - @property - def model_type(self) -> ModelFileType: - return ModelFileType - - @property - def tokenizer(self) -> HFTokenizer: - return self._tokenizer - - def get_vocab_file( - self, model_repo: str, file_name: str, file_path: pathlib.Path, - ) -> bool: - # NOTE: Do not use bare exceptions! They mask issues! - # Allow the exception to occur or explicitly handle it. - resolve_url = self.hub.resolve_url(model_repo, file_name) - response = self.hub.download_file(resolve_url) - self.hub.write_file(response.content, file_path) - self.logger.info(f"Downloaded tokenizer {file_name} from {model_repo}") - - def get_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None: - vocab_list = self.tokenizer.get_vocab_filenames(vocab_type) - for vocab_file in vocab_list: - dir_path = self.model_path / model_repo - file_path = dir_path / vocab_file - os.makedirs(dir_path, exist_ok=True) - self.get_vocab_file(model_repo, vocab_file, file_path) - - def get_model_file(self, model_repo: str, file_name: str, file_path: pathlib.Path) -> bool: - pass - - def get_all_model_files(self) -> None: - pass