refactor: Apply SoC for HF requests, vocab, and weights

teleprint-me 2024-05-18 13:45:21 -04:00
parent 5eda2c9485
commit 2ef73ee6e4


@@ -51,7 +51,7 @@ MODELS = (
 )
 
 
-class HuggingFaceHub:
+class HFHubRequest:
     def __init__(self, auth_token: None | str, logger: None | logging.Logger):
         # Set headers if authentication is available
         if auth_token is None:
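The hunk above cuts off right after the `if auth_token is None:` check, so the header handling itself is not visible here. As a rough sketch of what "set headers if authentication is available" commonly amounts to for Hugging Face Hub requests (an assumption for illustration, not the actual `HFHubRequest.__init__` body):

# Assumed sketch only -- not the actual HFHubRequest.__init__ body.
def build_headers(auth_token: None | str) -> dict[str, str]:
    headers: dict[str, str] = {}
    if auth_token is not None:
        # The HF Hub API accepts a standard bearer token for authenticated requests
        headers["Authorization"] = f"Bearer {auth_token}"
    return headers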
@@ -102,19 +102,19 @@ class HuggingFaceHub:
         return response
 
 
-class HFTokenizerRequest:
+class HFHubBase:
     def __init__(
         self,
-        dl_path: None | str | pathlib.Path,
+        model_path: None | str | pathlib.Path,
         auth_token: str,
         logger: None | logging.Logger
     ):
-        if dl_path is None:
-            self._local_path = pathlib.Path("models/tokenizers")
-        elif isinstance(dl_path, str):
-            self._local_path = pathlib.Path(dl_path)
+        if model_path is None:
+            self._model_path = pathlib.Path("models")
+        elif isinstance(model_path, str):
+            self._model_path = pathlib.Path(model_path)
         else:
-            self._local_path = dl_path
+            self._model_path = model_path
 
         # Set the logger
         if logger is None:
@@ -122,11 +122,11 @@ class HFTokenizerRequest:
             logger = logging.getLogger("hf-tok-req")
 
         self.logger = logger
-        self._hub = HuggingFaceHub(auth_token, logger)
+        self._hub = HFHubRequest(auth_token, logger)
         self._models = list(MODELS)
 
     @property
-    def hub(self) -> HuggingFaceHub:
+    def hub(self) -> HFHubRequest:
         return self._hub
 
     @property
@@ -134,16 +134,26 @@ class HFTokenizerRequest:
         return self._models
 
     @property
-    def tokenizer_type(self) -> TokenizerType:
-        return TokenizerType
+    def model_path(self) -> pathlib.Path:
+        return self._model_path
 
+    @model_path.setter
+    def model_path(self, value: pathlib.Path):
+        self._model_path = value
+
+
+class HFVocabRequest(HFHubBase):
+    def __init__(
+        self,
+        model_path: None | str | pathlib.Path,
+        auth_token: str,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, auth_token, logger)
+
     @property
-    def local_path(self) -> pathlib.Path:
-        return self._local_path
-
-    @local_path.setter
-    def local_path(self, value: pathlib.Path):
-        self._local_path = value
+    def tokenizer_type(self) -> TokenizerType:
+        return TokenizerType
 
     def resolve_filenames(self, tokt: TokenizerType) -> tuple[str]:
         filenames = ["config.json", "tokenizer_config.json", "tokenizer.json"]
@@ -200,3 +210,14 @@ class HFTokenizerRequest:
         if "ignore_merges" in cfg["model"]:
             self.logger.info(f"ignore_merges: {json.dumps(cfg['model']['ignore_merges'], indent=4)}")
         self.logger.info("")
+
+
+# TODO:
+class HFModelRequest(HFHubBase):
+    def __init__(
+        self,
+        model_path: None | str | pathlib.Path,
+        auth_token: str,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, auth_token, logger)
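Taken together, the separation of concerns reads as: raw Hub requests in HFHubRequest, shared path and logger state in HFHubBase, vocab/tokenizer handling in HFVocabRequest, and model weights in the (still stubbed) HFModelRequest. A minimal usage sketch, assuming only the constructors and properties visible in this diff are importable from the module being refactored; the token source and logging setup are illustrative:

import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-hub")
token = os.environ.get("HF_TOKEN", "")  # hypothetical source for auth_token

# Vocab/tokenizer concerns: HFVocabRequest builds on HFHubBase
vocab = HFVocabRequest("models", token, logger)
logger.info(vocab.model_path)      # shared path state from HFHubBase
logger.info(vocab.hub)             # raw HTTP concerns delegated to HFHubRequest
logger.info(vocab.tokenizer_type)  # vocab-specific: exposes TokenizerType

# Model weights are routed to HFModelRequest, currently just a TODO stub
weights = HFModelRequest("models", token, logger)

The design choice mirrored here is composition over inheritance for the request layer: subclasses reach the network only through the `hub` property rather than subclassing HFHubRequest directly.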