refactor: Streamline method signatures and clarify method names related to downloading repo files

This commit is contained in:
teleprint-me 2024-05-25 03:32:27 -04:00
parent 4438d052aa
commit fda2319d7b
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@@ -107,7 +107,7 @@ class HFHubRequest(HFHubBase):
def resolve_url(self, repo: str, filename: str) -> str: def resolve_url(self, repo: str, filename: str) -> str:
return f"{self._base_url}/{repo}/resolve/main/{filename}" return f"{self._base_url}/{repo}/resolve/main/{filename}"
def get(self, url: str) -> requests.Response: def get_response(self, url: str) -> requests.Response:
response = self._session.get(url, headers=self.headers) response = self._session.get(url, headers=self.headers)
self.logger.info(f"Response status was {response.status_code}") self.logger.info(f"Response status was {response.status_code}")
response.raise_for_status() response.raise_for_status()
@@ -129,7 +129,7 @@ class HFHubModel(HFHubBase):
def request(self) -> HFHubRequest: def request(self) -> HFHubRequest:
return self._request return self._request
def get_model_file( def _request_single_file(
self, model_repo: str, file_name: str, file_path: pathlib.Path self, model_repo: str, file_name: str, file_path: pathlib.Path
) -> bool: ) -> bool:
# NOTE: Do not use bare exceptions! They mask issues! # NOTE: Do not use bare exceptions! They mask issues!
@@ -137,33 +137,31 @@ class HFHubModel(HFHubBase):
try: try:
self.logger.info(f"Downloading '{file_name}' from {model_repo}") self.logger.info(f"Downloading '{file_name}' from {model_repo}")
resolved_url = self.request.resolve_url(model_repo, file_name) resolved_url = self.request.resolve_url(model_repo, file_name)
response = self.request.get(resolved_url) response = self.request.get_response(resolved_url)
if not response.is_success(): self.write_file(response.content, file_path)
raise ValueError("Failed to download model file")
self.request.write_file(response.content, file_path)
self.logger.info(f"Model file successfully saved to {file_path}") self.logger.info(f"Model file successfully saved to {file_path}")
return True return True
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
self.logger.error(f"Error while downloading '{file_name}': {str(e)}") self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
return False return False
def _download_files(self, model_repo: str, remote_files: list[str]) -> None: def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
for file_name in remote_files: for file_name in remote_files:
dir_path = self.model_path / model_repo dir_path = self.model_path / model_repo
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
self.get_model_file(model_repo, file_name, dir_path / file_name) self._request_single_file(model_repo, file_name, dir_path / file_name)
def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None: def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
filtered_files = self.list_filtered_remote_files(model_repo, file_type) filtered_files = self.list_filtered_remote_files(model_repo, file_type)
self._download_files(model_repo, filtered_files) self._request_listed_files(model_repo, filtered_files)
def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None: def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
vocab_files = self.tokenizer.get_vocab_filenames(vocab_type) vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
self._download_files(model_repo, vocab_files) self._request_listed_files(model_repo, vocab_files)
def download_all_model_files(self, model_repo: str) -> None: def download_all_model_files(self, model_repo: str) -> None:
all_files = self.list_remote_files(model_repo) all_files = self.list_remote_files(model_repo)
self._download_files(model_repo, all_files) self._request_listed_files(model_repo, all_files)
class HFHubTokenizer(HFHubBase): class HFHubTokenizer(HFHubBase):