refactor: Apply model schema to tokenizer downloads

- Add imports for json and hashlib
- Add missing models: phi, stablelm, mistral, and mixtral
- Fix constructor logic
- Fix how models are accessed
- Apply model schema to download_model method
This commit is contained in:
teleprint-me 2024-05-18 01:26:39 -04:00
parent f7515abf49
commit 302258721b
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@@ -1,7 +1,9 @@
+import json
 import logging
 import os
 import pathlib
 from enum import IntEnum, auto
+from hashlib import sha256

 import requests
 from transformers import AutoTokenizer
@@ -42,6 +44,10 @@ MODELS = (
     {"tokt": TokenizerType.WPM, "repo": "jinaai/jina-embeddings-v2-base-en", },  # WPM!
     {"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-es", },
     {"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-de", },
+    {"tokt": TokenizerType.BPE, "repo": "microsoft/phi-1", },
+    {"tokt": TokenizerType.BPE, "repo": "stabilityai/stablelm-2-zephyr-1_6b", },
+    {"tokt": TokenizerType.SPM, "repo": "mistralai/Mistral-7B-Instruct-v0.2", },
+    {"tokt": TokenizerType.SPM, "repo": "mistralai/Mixtral-8x7B-Instruct-v0.1", },
 )
@@ -103,10 +109,10 @@ class HFTokenizerRequest:
         auth_token: str,
         logger: None | logging.Logger
     ):
-        self._hub = HuggingFaceHub(auth_token, logger)
         if dl_path is None:
             self._local_path = pathlib.Path("models/tokenizers")
+        elif isinstance(dl_path, str):
+            self._local_path = pathlib.Path(dl_path)
         else:
             self._local_path = dl_path
@@ -116,13 +122,16 @@ class HFTokenizerRequest:
             logger = logging.getLogger("hf-tok-req")
         self.logger = logger

+        self._hub = HuggingFaceHub(auth_token, logger)
+        self._models = list(MODELS)
+
     @property
     def hub(self) -> HuggingFaceHub:
         return self._hub

     @property
     def models(self) -> list[dict[str, object]]:
-        return MODEL_REPOS
+        return self._models

     @property
     def tokenizer_type(self) -> TokenizerType:
@@ -157,12 +166,12 @@ class HFTokenizerRequest:
     def download_model(self) -> None:
         for model in self.models:
-            os.makedirs(f"{self.local_path}/{model['name']}", exist_ok=True)
+            os.makedirs(f"{self.local_path}/{model['repo']}", exist_ok=True)
             filenames = self.resolve_filenames(model['tokt'])
             for filename in filenames:
-                filepath = pathlib.Path(f"{self.local_path}/{model['name']}/{filename}")
+                filepath = pathlib.Path(f"{self.local_path}/{model['repo']}/{filename}")
                 if filepath.is_file():
-                    self.logger.info(f"skipped pre-existing tokenizer {model['name']} at {filepath}")
+                    self.logger.info(f"skipped pre-existing tokenizer {model['repo']} in {filepath}")
                     continue
                 self.resolve_tokenizer_model(filename, filepath, model)