refactor: Apply separation of concerns (SoC) to HF requests, vocab, and weights

This commit is contained in:
teleprint-me 2024-05-18 13:45:21 -04:00
parent 5eda2c9485
commit 2ef73ee6e4
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -51,7 +51,7 @@ MODELS = (
) )
class HuggingFaceHub: class HFHubRequest:
def __init__(self, auth_token: None | str, logger: None | logging.Logger): def __init__(self, auth_token: None | str, logger: None | logging.Logger):
# Set headers if authentication is available # Set headers if authentication is available
if auth_token is None: if auth_token is None:
@ -102,19 +102,19 @@ class HuggingFaceHub:
return response return response
class HFTokenizerRequest: class HFHubBase:
def __init__( def __init__(
self, self,
dl_path: None | str | pathlib.Path, model_path: None | str | pathlib.Path,
auth_token: str, auth_token: str,
logger: None | logging.Logger logger: None | logging.Logger
): ):
if dl_path is None: if model_path is None:
self._local_path = pathlib.Path("models/tokenizers") self._model_path = pathlib.Path("models")
elif isinstance(dl_path, str): elif isinstance(model_path, str):
self._local_path = pathlib.Path(dl_path) self._model_path = pathlib.Path(model_path)
else: else:
self._local_path = dl_path self._model_path = model_path
# Set the logger # Set the logger
if logger is None: if logger is None:
@ -122,11 +122,11 @@ class HFTokenizerRequest:
logger = logging.getLogger("hf-tok-req") logger = logging.getLogger("hf-tok-req")
self.logger = logger self.logger = logger
self._hub = HuggingFaceHub(auth_token, logger) self._hub = HFHubRequest(auth_token, logger)
self._models = list(MODELS) self._models = list(MODELS)
@property @property
def hub(self) -> HuggingFaceHub: def hub(self) -> HFHubRequest:
return self._hub return self._hub
@property @property
@ -134,16 +134,26 @@ class HFTokenizerRequest:
return self._models return self._models
@property @property
def tokenizer_type(self) -> TokenizerType: def model_path(self) -> pathlib.Path:
return TokenizerType return self._model_path
@model_path.setter
def model_path(self, value: str | pathlib.Path):
    """Replace the stored model path.

    A plain string is converted to ``pathlib.Path``, the same way
    ``__init__`` treats a string ``model_path``, so the ``model_path``
    getter keeps returning a ``pathlib.Path``. Any other value is stored
    unchanged.

    Args:
        value: New model path, as a string or a ``pathlib.Path``.
    """
    if isinstance(value, str):
        value = pathlib.Path(value)
    self._model_path = value
class HFVocabRequest(HFHubBase):
def __init__(
    self,
    model_path: None | str | pathlib.Path,
    auth_token: str,
    logger: None | logging.Logger,
):
    """Set up the vocab request by delegating to the parent initialiser.

    All three arguments are forwarded unchanged to ``HFHubBase.__init__``;
    no additional state is created here.

    Args:
        model_path: Passed through to the parent as ``model_path``.
        auth_token: Passed through to the parent as ``auth_token``.
        logger: Passed through to the parent as ``logger``.
    """
    super().__init__(
        model_path=model_path,
        auth_token=auth_token,
        logger=logger,
    )
@property @property
def local_path(self) -> pathlib.Path: def tokenizer_type(self) -> TokenizerType:
return self._local_path return TokenizerType
@local_path.setter
def local_path(self, value: pathlib.Path):
self._local_path = value
def resolve_filenames(self, tokt: TokenizerType) -> tuple[str]: def resolve_filenames(self, tokt: TokenizerType) -> tuple[str]:
filenames = ["config.json", "tokenizer_config.json", "tokenizer.json"] filenames = ["config.json", "tokenizer_config.json", "tokenizer.json"]
@ -200,3 +210,14 @@ class HFTokenizerRequest:
if "ignore_merges" in cfg["model"]: if "ignore_merges" in cfg["model"]:
self.logger.info(f"ignore_merges: {json.dumps(cfg['model']['ignore_merges'], indent=4)}") self.logger.info(f"ignore_merges: {json.dumps(cfg['model']['ignore_merges'], indent=4)}")
self.logger.info("") self.logger.info("")
# TODO: HFModelRequest is still a stub; it adds nothing beyond HFHubBase yet.
class HFModelRequest(HFHubBase):
    """Request class derived from ``HFHubBase``.

    Currently a placeholder: it defines only ``__init__``, which hands its
    arguments straight to the parent initialiser.
    """

    def __init__(
        self,
        model_path: None | str | pathlib.Path,
        auth_token: str,
        logger: None | logging.Logger,
    ):
        """Forward ``model_path``, ``auth_token`` and ``logger`` to ``HFHubBase``."""
        super().__init__(
            model_path=model_path,
            auth_token=auth_token,
            logger=logger,
        )