From 168297f11c8f1abb45052d116b031935e229ff20 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Fri, 24 May 2024 23:57:45 -0400 Subject: [PATCH] refactor: Add remote repository listings to the base HFHub class --- gguf-py/gguf/huggingface_hub.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py index 0d1c8d54b..4be75ab51 100644 --- a/gguf-py/gguf/huggingface_hub.py +++ b/gguf-py/gguf/huggingface_hub.py @@ -90,6 +90,26 @@ class HFHub: # Set the hub api self._hub = HFHubRequest(auth_token, logger) + # NOTE: Required for getting model_info + login(auth_token, add_to_git_credential=True) + + @staticmethod + def list_remote_files(model_repo: str) -> list[str]: + # NOTE: Request repository metadata to extract remote filenames + return [x.rfilename for x in model_info(model_repo).siblings] + + def list_remote_model_files(self, model_repo: str, model_type: ModelFileType) -> list[str]: + model_files = [] + self.logger.info(f"Repo:{model_repo}") + self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[model_type]}:{model_type}") + for filename in HFModel.list_remote_files(model_repo): + suffix = pathlib.Path(filename).suffix + self.logger.debug(f"Suffix: {suffix}") + if suffix == MODEL_FILE_TYPE_NAMES[model_type]: + self.logger.info(f"File: {filename}") + model_files.append(filename) + return model_files + @property def hub(self) -> HFHubRequest: return self._hub @@ -196,12 +216,6 @@ class HFModel(HFHub): ): super().__init__(model_path, auth_token, logger) self._tokenizer = HFTokenizer(model_path, auth_token, logger) - login(auth_token) # NOTE: Required for using model_info - - @staticmethod - def get_model_info(repo_id: str) -> list[str]: - # NOTE: Get repository metadata to extract remote filenames - return [x.rfilename for x in model_info(repo_id).siblings] @property def model_type(self) -> ModelFileType: