From fda2319d7be98d4d105ca2665db8a1340ec98898 Mon Sep 17 00:00:00 2001
From: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
Date: Sat, 25 May 2024 03:32:27 -0400
Subject: [PATCH] refactor: Streamline method signatures and clarify method names related to downloading repo files

---
 gguf-py/gguf/huggingface_hub.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py
index 1d97bf37f..869c84fd4 100644
--- a/gguf-py/gguf/huggingface_hub.py
+++ b/gguf-py/gguf/huggingface_hub.py
@@ -107,7 +107,7 @@ class HFHubRequest(HFHubBase):
     def resolve_url(self, repo: str, filename: str) -> str:
         return f"{self._base_url}/{repo}/resolve/main/{filename}"
 
-    def get(self, url: str) -> requests.Response:
+    def get_response(self, url: str) -> requests.Response:
         response = self._session.get(url, headers=self.headers)
         self.logger.info(f"Response status was {response.status_code}")
         response.raise_for_status()
@@ -129,7 +129,7 @@ class HFHubModel(HFHubBase):
     def request(self) -> HFHubRequest:
         return self._request
 
-    def get_model_file(
+    def _request_single_file(
         self, model_repo: str, file_name: str, file_path: pathlib.Path
     ) -> bool:
         # NOTE: Do not use bare exceptions! They mask issues!
@@ -137,33 +137,31 @@ class HFHubModel(HFHubBase):
         try:
             self.logger.info(f"Downloading '{file_name}' from {model_repo}")
             resolved_url = self.request.resolve_url(model_repo, file_name)
-            response = self.request.get(resolved_url)
-            if not response.is_success():
-                raise ValueError("Failed to download model file")
-            self.request.write_file(response.content, file_path)
+            response = self.request.get_response(resolved_url)
+            self.write_file(response.content, file_path)
             self.logger.info(f"Model file successfully saved to {file_path}")
             return True
         except requests.exceptions.HTTPError as e:
             self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
             return False
 
-    def _download_files(self, model_repo: str, remote_files: list[str]) -> None:
+    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
         for file_name in remote_files:
             dir_path = self.model_path / model_repo
             os.makedirs(dir_path, exist_ok=True)
-            self.get_model_file(model_repo, file_name, dir_path / file_name)
+            self._request_single_file(model_repo, file_name, dir_path / file_name)
 
     def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
         filtered_files = self.list_filtered_remote_files(model_repo, file_type)
-        self._download_files(model_repo, filtered_files)
+        self._request_listed_files(model_repo, filtered_files)
 
     def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
         vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
-        self._download_files(model_repo, vocab_files)
+        self._request_listed_files(model_repo, vocab_files)
 
     def download_all_model_files(self, model_repo: str) -> None:
         all_files = self.list_remote_files(model_repo)
-        self._download_files(model_repo, all_files)
+        self._request_listed_files(model_repo, all_files)
 
 
 class HFHubTokenizer(HFHubBase):