refactor: Apply model schema to tokenizer downloads
- Add imports for json and hashlib
- Add missing models: phi, stablelm, mistral, and mixtral
- Fix constructor logic
- Fix how models are accessed
- Apply model schema to download_model method
This commit is contained in:
parent
f7515abf49
commit
302258721b
1 changed file with 15 additions and 6 deletions
|
@@ -1,7 +1,9 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from enum import IntEnum, auto
|
||||
from hashlib import sha256
|
||||
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
@@ -42,6 +44,10 @@ MODELS = (
|
|||
{"tokt": TokenizerType.WPM, "repo": "jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||
{"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-es", },
|
||||
{"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-de", },
|
||||
{"tokt": TokenizerType.BPE, "repo": "microsoft/phi-1", },
|
||||
{"tokt": TokenizerType.BPE, "repo": "stabilityai/stablelm-2-zephyr-1_6b", },
|
||||
{"tokt": TokenizerType.SPM, "repo": "mistralai/Mistral-7B-Instruct-v0.2", },
|
||||
{"tokt": TokenizerType.SPM, "repo": "mistralai/Mixtral-8x7B-Instruct-v0.1", },
|
||||
)
|
||||
|
||||
|
||||
|
@@ -103,10 +109,10 @@ class HFTokenizerRequest:
|
|||
auth_token: str,
|
||||
logger: None | logging.Logger
|
||||
):
|
||||
self._hub = HuggingFaceHub(auth_token, logger)
|
||||
|
||||
if dl_path is None:
|
||||
self._local_path = pathlib.Path("models/tokenizers")
|
||||
elif isinstance(dl_path, str):
|
||||
self._local_path = pathlib.Path(dl_path)
|
||||
else:
|
||||
self._local_path = dl_path
|
||||
|
||||
|
@@ -116,13 +122,16 @@ class HFTokenizerRequest:
|
|||
logger = logging.getLogger("hf-tok-req")
|
||||
self.logger = logger
|
||||
|
||||
self._hub = HuggingFaceHub(auth_token, logger)
|
||||
self._models = list(MODELS)
|
||||
|
||||
@property
|
||||
def hub(self) -> HuggingFaceHub:
|
||||
return self._hub
|
||||
|
||||
@property
|
||||
def models(self) -> list[dict[str, object]]:
|
||||
return MODEL_REPOS
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def tokenizer_type(self) -> TokenizerType:
|
||||
|
@@ -157,12 +166,12 @@ class HFTokenizerRequest:
|
|||
|
||||
def download_model(self) -> None:
|
||||
for model in self.models:
|
||||
os.makedirs(f"{self.local_path}/{model['name']}", exist_ok=True)
|
||||
os.makedirs(f"{self.local_path}/{model['repo']}", exist_ok=True)
|
||||
filenames = self.resolve_filenames(model['tokt'])
|
||||
for filename in filenames:
|
||||
filepath = pathlib.Path(f"{self.local_path}/{model['name']}/{filename}")
|
||||
filepath = pathlib.Path(f"{self.local_path}/{model['repo']}/{filename}")
|
||||
if filepath.is_file():
|
||||
self.logger.info(f"skipped pre-existing tokenizer {model['name']} at {filepath}")
|
||||
self.logger.info(f"skipped pre-existing tokenizer {model['repo']} in {filepath}")
|
||||
continue
|
||||
self.resolve_tokenizer_model(filename, filepath, model)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue