refactor: Streamline method signatures and clarify method names related to downloading repo files
This commit is contained in:
parent
4438d052aa
commit
fda2319d7b
1 changed files with 9 additions and 11 deletions
|
@ -107,7 +107,7 @@ class HFHubRequest(HFHubBase):
|
||||||
def resolve_url(self, repo: str, filename: str) -> str:
|
def resolve_url(self, repo: str, filename: str) -> str:
|
||||||
return f"{self._base_url}/{repo}/resolve/main/{filename}"
|
return f"{self._base_url}/{repo}/resolve/main/{filename}"
|
||||||
|
|
||||||
def get(self, url: str) -> requests.Response:
|
def get_response(self, url: str) -> requests.Response:
|
||||||
response = self._session.get(url, headers=self.headers)
|
response = self._session.get(url, headers=self.headers)
|
||||||
self.logger.info(f"Response status was {response.status_code}")
|
self.logger.info(f"Response status was {response.status_code}")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -129,7 +129,7 @@ class HFHubModel(HFHubBase):
|
||||||
def request(self) -> HFHubRequest:
|
def request(self) -> HFHubRequest:
|
||||||
return self._request
|
return self._request
|
||||||
|
|
||||||
def get_model_file(
|
def _request_single_file(
|
||||||
self, model_repo: str, file_name: str, file_path: pathlib.Path
|
self, model_repo: str, file_name: str, file_path: pathlib.Path
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# NOTE: Do not use bare exceptions! They mask issues!
|
# NOTE: Do not use bare exceptions! They mask issues!
|
||||||
|
@ -137,33 +137,31 @@ class HFHubModel(HFHubBase):
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"Downloading '{file_name}' from {model_repo}")
|
self.logger.info(f"Downloading '{file_name}' from {model_repo}")
|
||||||
resolved_url = self.request.resolve_url(model_repo, file_name)
|
resolved_url = self.request.resolve_url(model_repo, file_name)
|
||||||
response = self.request.get(resolved_url)
|
response = self.request.get_response(resolved_url)
|
||||||
if not response.is_success():
|
self.write_file(response.content, file_path)
|
||||||
raise ValueError("Failed to download model file")
|
|
||||||
self.request.write_file(response.content, file_path)
|
|
||||||
self.logger.info(f"Model file successfully saved to {file_path}")
|
self.logger.info(f"Model file successfully saved to {file_path}")
|
||||||
return True
|
return True
|
||||||
except requests.exceptions.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
|
self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _download_files(self, model_repo: str, remote_files: list[str]) -> None:
|
def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
|
||||||
for file_name in remote_files:
|
for file_name in remote_files:
|
||||||
dir_path = self.model_path / model_repo
|
dir_path = self.model_path / model_repo
|
||||||
os.makedirs(dir_path, exist_ok=True)
|
os.makedirs(dir_path, exist_ok=True)
|
||||||
self.get_model_file(model_repo, file_name, dir_path / file_name)
|
self._request_single_file(model_repo, file_name, dir_path / file_name)
|
||||||
|
|
||||||
def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
|
def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
|
||||||
filtered_files = self.list_filtered_remote_files(model_repo, file_type)
|
filtered_files = self.list_filtered_remote_files(model_repo, file_type)
|
||||||
self._download_files(model_repo, filtered_files)
|
self._request_listed_files(model_repo, filtered_files)
|
||||||
|
|
||||||
def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
|
def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
|
||||||
vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
|
vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
|
||||||
self._download_files(model_repo, vocab_files)
|
self._request_listed_files(model_repo, vocab_files)
|
||||||
|
|
||||||
def download_all_model_files(self, model_repo: str) -> None:
|
def download_all_model_files(self, model_repo: str) -> None:
|
||||||
all_files = self.list_remote_files(model_repo)
|
all_files = self.list_remote_files(model_repo)
|
||||||
self._download_files(model_repo, all_files)
|
self._request_listed_files(model_repo, all_files)
|
||||||
|
|
||||||
|
|
||||||
class HFHubTokenizer(HFHubBase):
|
class HFHubTokenizer(HFHubBase):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue