refactor: Streamline method signatures and clarify method names related to downloading repo files
This commit is contained in:
parent
4438d052aa
commit
fda2319d7b
1 changed file with 9 additions and 11 deletions
|
@@ -107,7 +107,7 @@ class HFHubRequest(HFHubBase):
|
|||
def resolve_url(self, repo: str, filename: str) -> str:
|
||||
return f"{self._base_url}/{repo}/resolve/main/{filename}"
|
||||
|
||||
def get(self, url: str) -> requests.Response:
|
||||
def get_response(self, url: str) -> requests.Response:
|
||||
response = self._session.get(url, headers=self.headers)
|
||||
self.logger.info(f"Response status was {response.status_code}")
|
||||
response.raise_for_status()
|
||||
|
@@ -129,7 +129,7 @@ class HFHubModel(HFHubBase):
|
|||
def request(self) -> HFHubRequest:
|
||||
return self._request
|
||||
|
||||
def get_model_file(
|
||||
def _request_single_file(
|
||||
self, model_repo: str, file_name: str, file_path: pathlib.Path
|
||||
) -> bool:
|
||||
# NOTE: Do not use bare exceptions! They mask issues!
|
||||
|
@@ -137,33 +137,31 @@ class HFHubModel(HFHubBase):
|
|||
try:
|
||||
self.logger.info(f"Downloading '{file_name}' from {model_repo}")
|
||||
resolved_url = self.request.resolve_url(model_repo, file_name)
|
||||
response = self.request.get(resolved_url)
|
||||
if not response.is_success():
|
||||
raise ValueError("Failed to download model file")
|
||||
self.request.write_file(response.content, file_path)
|
||||
response = self.request.get_response(resolved_url)
|
||||
self.write_file(response.content, file_path)
|
||||
self.logger.info(f"Model file successfully saved to {file_path}")
|
||||
return True
|
||||
except requests.exceptions.HTTPError as e:
|
||||
self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
|
||||
return False
|
||||
|
||||
def _download_files(self, model_repo: str, remote_files: list[str]) -> None:
|
||||
def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
|
||||
for file_name in remote_files:
|
||||
dir_path = self.model_path / model_repo
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
self.get_model_file(model_repo, file_name, dir_path / file_name)
|
||||
self._request_single_file(model_repo, file_name, dir_path / file_name)
|
||||
|
||||
def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
|
||||
filtered_files = self.list_filtered_remote_files(model_repo, file_type)
|
||||
self._download_files(model_repo, filtered_files)
|
||||
self._request_listed_files(model_repo, filtered_files)
|
||||
|
||||
def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
|
||||
vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
|
||||
self._download_files(model_repo, vocab_files)
|
||||
self._request_listed_files(model_repo, vocab_files)
|
||||
|
||||
def download_all_model_files(self, model_repo: str) -> None:
|
||||
all_files = self.list_remote_files(model_repo)
|
||||
self._download_files(model_repo, all_files)
|
||||
self._request_listed_files(model_repo, all_files)
|
||||
|
||||
|
||||
class HFHubTokenizer(HFHubBase):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue