feat: Add static methods for resolving model types and model extensions

This commit is contained in:
teleprint-me 2024-05-25 19:11:56 -04:00
parent f30bd63252
commit da72554f58
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -219,6 +219,22 @@ class HFHubModel(HFHubBase):
self._request = HFHubRequest(auth_token, model_path, logger) self._request = HFHubRequest(auth_token, model_path, logger)
self._tokenizer = HFHubTokenizer(model_path, logger) self._tokenizer = HFHubTokenizer(model_path, logger)
@staticmethod
def get_model_type_name(model_type: ModelFileType) -> str:
return MODEL_FILE_TYPE_NAMES.get(model_type, "")
@staticmethod
def get_model_type(model_name: str) -> ModelFileType:
return {
".pt": ModelFileType.PT,
".pth": ModelFileType.PTH,
".bin": ModelFileType.BIN,
".safetensors": ModelFileType.SAFETENSORS,
".json": ModelFileType.JSON,
".model": ModelFileType.MODEL,
".gguf": ModelFileType.GGUF,
}.get(model_name, ModelFileType.NON)
@property @property
def request(self) -> HFHubRequest: def request(self) -> HFHubRequest:
return self._request return self._request