feat: Attempt to mirror the llama.cpp API for compatibility
commit 89a46fe818
parent c6f2a48af7
1 changed file with 36 additions and 33 deletions
@@ -8,14 +8,17 @@ from hashlib import sha256
 import requests
 from transformers import AutoTokenizer
 
+from .constants import MODEL_ARCH
-class TokenizerType(IntEnum):
-    SPM = auto()  # SentencePiece
-    BPE = auto()  # BytePair
-    WPM = auto()  # WordPiece
 
 
-class ModelType(IntEnum):
+class LLaMaVocabType(IntEnum):
+    NON = auto()  # For models without vocab
+    SPM = auto()  # SentencePiece LLaMa tokenizer
+    BPE = auto()  # BytePair GPT-2 tokenizer
+    WPM = auto()  # WordPiece BERT tokenizer
 
+
+class LLaMaModelType(IntEnum):
     PTH = auto()  # PyTorch
     SFT = auto()  # SafeTensor
 
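Since the commit's stated goal is llama.cpp parity, note how the new LLaMaVocabType lines up with llama.cpp's llama_vocab_type. One caveat: Python's auto() starts counting at 1, while LLAMA_VOCAB_TYPE_NONE is 0 in llama.h, so exact value parity would need explicit values. A minimal, purely illustrative sketch (LLaMaVocabTypeMirror is a hypothetical name, not part of this commit):

from enum import IntEnum

# Hypothetical sketch, not part of this commit: explicit values matching
# llama.cpp's llama_vocab_type, where LLAMA_VOCAB_TYPE_NONE == 0.
class LLaMaVocabTypeMirror(IntEnum):
    NON = 0  # LLAMA_VOCAB_TYPE_NONE: models without a vocab
    SPM = 1  # LLAMA_VOCAB_TYPE_SPM: SentencePiece
    BPE = 2  # LLAMA_VOCAB_TYPE_BPE: byte-pair encoding
    WPM = 3  # LLAMA_VOCAB_TYPE_WPM: WordPiece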
@@ -26,28 +29,28 @@ class ModelType(IntEnum):
 # - Use architecture types because they are explicitly defined
 # - Possible tokenizer model types are: SentencePiece, WordPiece, or BytePair
 MODELS = (
-    {"tokt": TokenizerType.SPM, "repo": "meta-llama/Llama-2-7b-hf", },
+    {"arch": MODEL_ARCH.LLAMA, "vocab_type": LLaMaVocabType.SPM, "repo": "meta-llama/Llama-2-7b-hf", },
-    {"tokt": TokenizerType.BPE, "repo": "meta-llama/Meta-Llama-3-8B", },
+    {"arch": MODEL_ARCH.LLAMA, "vocab_type": LLaMaVocabType.BPE, "repo": "meta-llama/Meta-Llama-3-8B", },
-    {"tokt": TokenizerType.SPM, "repo": "microsoft/Phi-3-mini-4k-instruct", },
+    {"arch": MODEL_ARCH.PHI3, "vocab_type": LLaMaVocabType.SPM, "repo": "microsoft/Phi-3-mini-4k-instruct", },
-    {"tokt": TokenizerType.BPE, "repo": "deepseek-ai/deepseek-llm-7b-base", },
+    {"arch": None, "vocab_type": LLaMaVocabType.BPE, "repo": "deepseek-ai/deepseek-llm-7b-base", },
-    {"tokt": TokenizerType.BPE, "repo": "deepseek-ai/deepseek-coder-6.7b-base", },
+    {"arch": None, "vocab_type": LLaMaVocabType.BPE, "repo": "deepseek-ai/deepseek-coder-6.7b-base", },
-    {"tokt": TokenizerType.BPE, "repo": "tiiuae/falcon-7b", },
+    {"arch": MODEL_ARCH.FALCON, "vocab_type": LLaMaVocabType.BPE, "repo": "tiiuae/falcon-7b", },
-    {"tokt": TokenizerType.WPM, "repo": "BAAI/bge-small-en-v1.5", },
+    {"arch": None, "vocab_type": LLaMaVocabType.WPM, "repo": "BAAI/bge-small-en-v1.5", },
-    {"tokt": TokenizerType.BPE, "repo": "mosaicml/mpt-7b", },
+    {"arch": MODEL_ARCH.MPT, "vocab_type": LLaMaVocabType.BPE, "repo": "mosaicml/mpt-7b", },
-    {"tokt": TokenizerType.BPE, "repo": "bigcode/starcoder2-3b", },
+    {"arch": MODEL_ARCH.STARCODER2, "vocab_type": LLaMaVocabType.BPE, "repo": "bigcode/starcoder2-3b", },
-    {"tokt": TokenizerType.BPE, "repo": "openai-community/gpt2", },
+    {"arch": MODEL_ARCH.GPT2, "vocab_type": LLaMaVocabType.BPE, "repo": "openai-community/gpt2", },
-    {"tokt": TokenizerType.BPE, "repo": "smallcloudai/Refact-1_6-base", },
+    {"arch": MODEL_ARCH.REFACT, "vocab_type": LLaMaVocabType.BPE, "repo": "smallcloudai/Refact-1_6-base", },
-    {"tokt": TokenizerType.BPE, "repo": "CohereForAI/c4ai-command-r-v01", },
+    {"arch": MODEL_ARCH.COMMAND_R, "vocab_type": LLaMaVocabType.BPE, "repo": "CohereForAI/c4ai-command-r-v01", },
-    {"tokt": TokenizerType.BPE, "repo": "Qwen/Qwen1.5-7B", },
+    {"arch": MODEL_ARCH.QWEN2, "vocab_type": LLaMaVocabType.BPE, "repo": "Qwen/Qwen1.5-7B", },
-    {"tokt": TokenizerType.BPE, "repo": "allenai/OLMo-1.7-7B-hf", },
+    {"arch": MODEL_ARCH.OLMO, "vocab_type": LLaMaVocabType.BPE, "repo": "allenai/OLMo-1.7-7B-hf", },
-    {"tokt": TokenizerType.BPE, "repo": "databricks/dbrx-base", },
+    {"arch": MODEL_ARCH.DBRX, "vocab_type": LLaMaVocabType.BPE, "repo": "databricks/dbrx-base", },
-    {"tokt": TokenizerType.WPM, "repo": "jinaai/jina-embeddings-v2-base-en", },  # WPM!
+    {"arch": MODEL_ARCH.JINA_BERT_V2, "vocab_type": LLaMaVocabType.WPM, "repo": "jinaai/jina-embeddings-v2-base-en", },
-    {"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-es", },
+    {"arch": MODEL_ARCH.JINA_BERT_V2, "vocab_type": LLaMaVocabType.BPE, "repo": "jinaai/jina-embeddings-v2-base-es", },
-    {"tokt": TokenizerType.BPE, "repo": "jinaai/jina-embeddings-v2-base-de", },
+    {"arch": MODEL_ARCH.JINA_BERT_V2, "vocab_type": LLaMaVocabType.BPE, "repo": "jinaai/jina-embeddings-v2-base-de", },
-    {"tokt": TokenizerType.BPE, "repo": "microsoft/phi-1", },
+    {"arch": MODEL_ARCH.PHI2, "vocab_type": LLaMaVocabType.BPE, "repo": "microsoft/phi-1", },
-    {"tokt": TokenizerType.BPE, "repo": "stabilityai/stablelm-2-zephyr-1_6b", },
+    {"arch": MODEL_ARCH.STABLELM, "vocab_type": LLaMaVocabType.BPE, "repo": "stabilityai/stablelm-2-zephyr-1_6b", },
-    {"tokt": TokenizerType.SPM, "repo": "mistralai/Mistral-7B-Instruct-v0.2", },
+    {"arch": MODEL_ARCH.LLAMA, "vocab_type": LLaMaVocabType.SPM, "repo": "mistralai/Mistral-7B-Instruct-v0.2", },
-    {"tokt": TokenizerType.SPM, "repo": "mistralai/Mixtral-8x7B-Instruct-v0.1", },
+    {"arch": MODEL_ARCH.LLAMA, "vocab_type": LLaMaVocabType.SPM, "repo": "mistralai/Mixtral-8x7B-Instruct-v0.1", },
 )
 
 
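A short usage sketch of the reshaped MODELS table (illustrative only; repos_by_vocab is a hypothetical helper, and entries with "arch": None are those not yet mapped to a MODEL_ARCH value):

# Hypothetical helper, not part of this commit: select repos by vocab type
# and list entries whose architecture is still unmapped ("arch": None).
def repos_by_vocab(vocab_type: LLaMaVocabType) -> list[str]:
    return [m["repo"] for m in MODELS if m["vocab_type"] == vocab_type]

spm_repos = repos_by_vocab(LLaMaVocabType.SPM)
unmapped = [m["repo"] for m in MODELS if m["arch"] is None]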
@@ -152,10 +155,10 @@ class HFVocabRequest(HFHubBase):
         super().__init__(model_path, auth_token, logger)
 
     @property
-    def tokenizer_type(self) -> TokenizerType:
-        return TokenizerType
+    def tokenizer_type(self) -> LLaMaVocabType:
+        return LLaMaVocabType
 
-    def resolve_filenames(self, tokt: TokenizerType) -> tuple[str]:
+    def resolve_filenames(self, tokt: LLaMaVocabType) -> tuple[str]:
         filenames = ["config.json", "tokenizer_config.json", "tokenizer.json"]
         if tokt == self.tokenizer_type.SPM:
             filenames.append("tokenizer.model")
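Usage sketch for the updated resolve_filenames signature (illustrative; the constructor arguments mirror the super().__init__ call shown above, and only SPM vocabularies pull in the extra SentencePiece model file):

# Hypothetical usage, not part of this commit.
vocab_req = HFVocabRequest(model_path, auth_token, logger)
files = vocab_req.resolve_filenames(LLaMaVocabType.SPM)
# files now includes "tokenizer.model" alongside the three JSON files;
# for BPE or WPM it would hold only the JSON files.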
@@ -233,5 +236,5 @@ class HFModelRequest(HFHubBase):
         super().__init__(model_path, auth_token, logger)
 
     @property
-    def model_type(self) -> ModelType:
-        return ModelType
+    def model_type(self) -> LLaMaModelType:
+        return LLaMaModelType
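Note that both tokenizer_type and model_type return the enum class itself rather than a member, so callers reach members through the instance; a minimal sketch (model_req and the suffix check are hypothetical):

# Hypothetical dispatch, not part of this commit.
suffix = ".safetensors"
kind = model_req.model_type.SFT if suffix == ".safetensors" else model_req.model_type.PTH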