From 4438d052aa55d50bb3cb42a316c1dc9b8f5d6ef5 Mon Sep 17 00:00:00 2001
From: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
Date: Sat, 25 May 2024 02:57:59 -0400
Subject: [PATCH] refactor: Abstract file and logger management to streamline api interface

---
 gguf-py/gguf/huggingface_hub.py | 190 +++++++++++++++++---------------
 1 file changed, 100 insertions(+), 90 deletions(-)

diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py
index 519eb94ed..1d97bf37f 100644
--- a/gguf-py/gguf/huggingface_hub.py
+++ b/gguf-py/gguf/huggingface_hub.py
@@ -19,61 +19,10 @@ from .constants import (
 )
 
 
-class HFHubRequest:
-    def __init__(self, auth_token: None | str, logger: None | logging.Logger):
-        # Set headers if authentication is available
-        if auth_token is None:
-            self._headers = None
-        else:
-            self._headers = {"Authorization": f"Bearer {auth_token}"}
-
-        # Set the logger
-        if logger is None:
-            logger = logging.getLogger(__name__)
-        self.logger = logger
-
-        # Persist across requests
-        self._session = requests.Session()
-
-        # This is read-only
-        self._base_url = "https://huggingface.co"
-
-    @property
-    def headers(self) -> str:
-        return self._headers
-
-    @property
-    def save_path(self) -> pathlib.Path:
-        return self._save_path
-
-    @property
-    def session(self) -> requests.Session:
-        return self._session
-
-    @property
-    def base_url(self) -> str:
-        return self._base_url
-
-    def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
-        with open(file_path, 'wb') as f:
-            f.write(content)
-        self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
-
-    def resolve_url(self, repo: str, filename: str) -> str:
-        return f"{self._base_url}/{repo}/resolve/main/{filename}"
-
-    def download_file(self, url: str) -> requests.Response:
-        response = self._session.get(url, headers=self.headers)
-        self.logger.info(f"Response status was {response.status_code}")
-        response.raise_for_status()
-        return response
-
-
-class HFHub:
+class HFHubBase:
     def __init__(
         self,
         model_path: None | str | pathlib.Path,
-        auth_token: str,
         logger: None | logging.Logger
     ):
         # Set the model path
@@ -83,39 +32,9 @@ class HFHub:
 
         # Set the logger
         if logger is None:
-            logging.basicConfig(level=logging.DEBUG)
             logger = logging.getLogger(__name__)
         self.logger = logger
 
-        # Set the hub api
-        self._hub = HFHubRequest(auth_token, logger)
-
-        # NOTE: Required for getting model_info
-        login(auth_token, add_to_git_credential=True)
-
-    @staticmethod
-    def list_remote_files(model_repo: str) -> list[str]:
-        # NOTE: Request repository metadata to extract remote filenames
-        return [x.rfilename for x in model_info(model_repo).siblings]
-
-    def filter_remote_files(
-        self, model_repo: str, file_type: ModelFileType
-    ) -> list[str]:
-        model_files = []
-        self.logger.info(f"Repo:{model_repo}")
-        self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[file_type]}:{file_type}")
-        for filename in HFHub.list_remote_files(model_repo):
-            suffix = pathlib.Path(filename).suffix
-            self.logger.debug(f"Suffix: {suffix}")
-            if suffix == MODEL_FILE_TYPE_NAMES[file_type]:
-                self.logger.info(f"File: {filename}")
-                model_files.append(filename)
-        return model_files
-
-    @property
-    def hub(self) -> HFHubRequest:
-        return self._hub
-
     @property
     def model_path(self) -> pathlib.Path:
         return pathlib.Path(self._model_path)
@@ -124,6 +43,92 @@ class HFHub:
     def model_path(self, value: pathlib.Path):
         self._model_path = value
 
+    def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
+        with open(file_path, 'wb') as file:
+            file.write(content)
+        self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
+
+
+class HFHubRequest(HFHubBase):
+    def __init__(
+        self,
+        auth_token: None | str,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, logger)
+
+        # Set headers if authentication is available
+        if auth_token is None:
+            self._headers = None
+        else:
+            self._headers = {"Authorization": f"Bearer {auth_token}"}
+
+        # Persist across requests
+        self._session = requests.Session()
+
+        # This is read-only
+        self._base_url = "https://huggingface.co"
+
+        # NOTE: Required for getting model_info
+        login(auth_token, add_to_git_credential=True)
+
+    @property
+    def headers(self) -> str:
+        return self._headers
+
+    @property
+    def session(self) -> requests.Session:
+        return self._session
+
+    @property
+    def base_url(self) -> str:
+        return self._base_url
+
+    @staticmethod
+    def list_remote_files(model_repo: str) -> list[str]:
+        # NOTE: Request repository metadata to extract remote filenames
+        return [x.rfilename for x in model_info(model_repo).siblings]
+
+    def list_filtered_remote_files(
+        self, model_repo: str, file_type: ModelFileType
+    ) -> list[str]:
+        model_files = []
+        self.logger.info(f"Repo:{model_repo}")
+        self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[file_type]}:{file_type}")
+        for filename in HFHubRequest.list_remote_files(model_repo):
+            suffix = pathlib.Path(filename).suffix
+            self.logger.debug(f"Suffix: {suffix}")
+            if suffix == MODEL_FILE_TYPE_NAMES[file_type]:
+                self.logger.info(f"File: {filename}")
+                model_files.append(filename)
+        return model_files
+
+    def resolve_url(self, repo: str, filename: str) -> str:
+        return f"{self._base_url}/{repo}/resolve/main/{filename}"
+
+    def get(self, url: str) -> requests.Response:
+        response = self._session.get(url, headers=self.headers)
+        self.logger.info(f"Response status was {response.status_code}")
+        response.raise_for_status()
+        return response
+
+
+class HFHubModel(HFHubBase):
+    def __init__(
+        self,
+        auth_token: None | str,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, logger)
+
+        self._request = HFHubRequest(auth_token, model_path, logger)
+
+    @property
+    def request(self) -> HFHubRequest:
+        return self._request
+
     def get_model_file(
         self, model_repo: str, file_name: str, file_path: pathlib.Path
     ) -> bool:
@@ -131,11 +136,11 @@ class HFHub:
         # Allow the exception to occur or explicitly handle it.
         try:
             self.logger.info(f"Downloading '{file_name}' from {model_repo}")
-            resolve_url = self.hub.resolve_url(model_repo, file_name)
-            response = self.hub.download_file(resolve_url)
+            resolved_url = self.request.resolve_url(model_repo, file_name)
+            response = self.request.get(resolved_url)
             if not response.is_success():
                 raise ValueError("Failed to download model file")
-            self.hub.write_file(response.content, file_path)
+            self.request.write_file(response.content, file_path)
             self.logger.info(f"Model file successfully saved to {file_path}")
             return True
         except requests.exceptions.HTTPError as e:
@@ -149,7 +154,7 @@ class HFHub:
             self.get_model_file(model_repo, file_name, dir_path / file_name)
 
     def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
-        filtered_files = self.filter_remote_files(model_repo, file_type)
+        filtered_files = self.list_filtered_remote_files(model_repo, file_type)
         self._download_files(model_repo, filtered_files)
 
     def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
@@ -161,9 +166,13 @@ class HFHub:
         self._download_files(model_repo, all_files)
 
 
-class HFTokenizer(HFHub):
-    def __init__(self, model_path: str, auth_token: str, logger: logging.Logger):
-        super().__init__(model_path, auth_token, logger)
+class HFHubTokenizer(HFHubBase):
+    def __init__(
+        self,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, logger)
 
     @staticmethod
     def get_vocab_filenames(vocab_type: VocabType) -> tuple[str]:
@@ -215,9 +224,10 @@ class HFTokenizer(HFHub):
         return normalizer
 
     def get_pre_tokenizer(self, model_repo: str) -> None | dict[str, object]:
-        pre_tokenizer = self.tokenizer_json(model_repo).get("pre_tokenizer", dict())
+        pre_tokenizer = self.tokenizer_json(model_repo).get("pre_tokenizer")
         if pre_tokenizer:
             self.logger.info(f"JSON:PreTokenizer: {json.dumps(pre_tokenizer, indent=2)}")
+            return pre_tokenizer
         else:
             self.logger.warn(f"WARN:PreTokenizer: {pre_tokenizer}")
         return pre_tokenizer