refactor: Add remote repository listings to the bas HFHub class
This commit is contained in:
parent
6da2bd6fbc
commit
168297f11c
1 changed files with 20 additions and 6 deletions
|
@ -90,6 +90,26 @@ class HFHub:
|
||||||
# Set the hub api
|
# Set the hub api
|
||||||
self._hub = HFHubRequest(auth_token, logger)
|
self._hub = HFHubRequest(auth_token, logger)
|
||||||
|
|
||||||
|
# NOTE: Required for getting model_info
|
||||||
|
login(auth_token, add_to_git_credential=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_remote_files(model_repo: str) -> list[str]:
|
||||||
|
# NOTE: Request repository metadata to extract remote filenames
|
||||||
|
return [x.rfilename for x in model_info(model_repo).siblings]
|
||||||
|
|
||||||
|
def list_remote_model_files(self, model_repo: str, model_type: ModelFileType) -> list[str]:
|
||||||
|
model_files = []
|
||||||
|
self.logger.info(f"Repo:{model_repo}")
|
||||||
|
self.logger.debug(f"Match:{MODEL_FILE_TYPE_NAMES[model_type]}:{model_type}")
|
||||||
|
for filename in HFModel.list_remote_files(model_repo):
|
||||||
|
suffix = pathlib.Path(filename).suffix
|
||||||
|
self.logger.debug(f"Suffix: {suffix}")
|
||||||
|
if suffix == MODEL_FILE_TYPE_NAMES[model_type]:
|
||||||
|
self.logger.info(f"File: {filename}")
|
||||||
|
model_files.append(filename)
|
||||||
|
return model_files
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hub(self) -> HFHubRequest:
|
def hub(self) -> HFHubRequest:
|
||||||
return self._hub
|
return self._hub
|
||||||
|
@ -196,12 +216,6 @@ class HFModel(HFHub):
|
||||||
):
|
):
|
||||||
super().__init__(model_path, auth_token, logger)
|
super().__init__(model_path, auth_token, logger)
|
||||||
self._tokenizer = HFTokenizer(model_path, auth_token, logger)
|
self._tokenizer = HFTokenizer(model_path, auth_token, logger)
|
||||||
login(auth_token) # NOTE: Required for using model_info
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_info(repo_id: str) -> list[str]:
|
|
||||||
# NOTE: Get repository metadata to extract remote filenames
|
|
||||||
return [x.rfilename for x in model_info(repo_id).siblings]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> ModelFileType:
|
def model_type(self) -> ModelFileType:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue