refactor: Add remote repository listings to the bas HFHub class
This commit is contained in:
parent
6da2bd6fbc
commit
168297f11c
1 changed files with 20 additions and 6 deletions
|
@ -90,6 +90,26 @@ class HFHub:
|
|||
# Set the hub api
|
||||
self._hub = HFHubRequest(auth_token, logger)
|
||||
|
||||
# NOTE: Required for getting model_info
|
||||
login(auth_token, add_to_git_credential=True)
|
||||
|
||||
@staticmethod
|
||||
def list_remote_files(model_repo: str) -> list[str]:
|
||||
# NOTE: Request repository metadata to extract remote filenames
|
||||
return [x.rfilename for x in model_info(model_repo).siblings]
|
||||
|
||||
def list_remote_model_files(self, model_repo: str, model_type: ModelFileType) -> list[str]:
|
||||
model_files = []
|
||||
self.logger.info(f"Repo:{model_repo}")
|
||||
self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[model_type]}:{model_type}")
|
||||
for filename in HFModel.list_remote_files(model_repo):
|
||||
suffix = pathlib.Path(filename).suffix
|
||||
self.logger.debug(f"Suffix: {suffix}")
|
||||
if suffix == MODEL_FILE_TYPE_NAMES[model_type]:
|
||||
self.logger.info(f"File: {filename}")
|
||||
model_files.append(filename)
|
||||
return model_files
|
||||
|
||||
@property
|
||||
def hub(self) -> HFHubRequest:
|
||||
return self._hub
|
||||
|
@ -196,12 +216,6 @@ class HFModel(HFHub):
|
|||
):
|
||||
super().__init__(model_path, auth_token, logger)
|
||||
self._tokenizer = HFTokenizer(model_path, auth_token, logger)
|
||||
login(auth_token) # NOTE: Required for using model_info
|
||||
|
||||
@staticmethod
|
||||
def get_model_info(repo_id: str) -> list[str]:
|
||||
# NOTE: Get repository metadata to extract remote filenames
|
||||
return [x.rfilename for x in model_info(repo_id).siblings]
|
||||
|
||||
@property
|
||||
def model_type(self) -> ModelFileType:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue