refactor: Add remote repository listings to the bas HFHub class

This commit is contained in:
teleprint-me 2024-05-24 23:57:45 -04:00
parent 6da2bd6fbc
commit 168297f11c
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -90,6 +90,26 @@ class HFHub:
# Set the hub api
self._hub = HFHubRequest(auth_token, logger)
# NOTE: Required for getting model_info
login(auth_token, add_to_git_credential=True)
@staticmethod
def list_remote_files(model_repo: str) -> list[str]:
# NOTE: Request repository metadata to extract remote filenames
return [x.rfilename for x in model_info(model_repo).siblings]
def list_remote_model_files(self, model_repo: str, model_type: ModelFileType) -> list[str]:
model_files = []
self.logger.info(f"Repo:{model_repo}")
self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[model_type]}:{model_type}")
for filename in HFModel.list_remote_files(model_repo):
suffix = pathlib.Path(filename).suffix
self.logger.debug(f"Suffix: {suffix}")
if suffix == MODEL_FILE_TYPE_NAMES[model_type]:
self.logger.info(f"File: {filename}")
model_files.append(filename)
return model_files
@property
def hub(self) -> HFHubRequest:
return self._hub
@ -196,12 +216,6 @@ class HFModel(HFHub):
):
super().__init__(model_path, auth_token, logger)
self._tokenizer = HFTokenizer(model_path, auth_token, logger)
login(auth_token) # NOTE: Required for using model_info
@staticmethod
def get_model_info(repo_id: str) -> list[str]:
# NOTE: Get repository metadata to extract remote filenames
return [x.rfilename for x in model_info(repo_id).siblings]
@property
def model_type(self) -> ModelFileType: