refactor: Clean up constants and simplify the custom hf hub api
commit 5836d6c7e7 (parent ce8524afd0)

2 changed files with 161 additions and 106 deletions
@@ -29,6 +29,7 @@ class GGUFMetadataKeys:
         SOURCE_URL = "general.source.url"
         SOURCE_REPO = "general.source.repository"
         FILE_TYPE = "general.file_type"
+        ENDIANESS = "general.endianess"

     class LLM:
         VOCAB_SIZE = "{arch}.vocab_size"
@@ -90,7 +91,7 @@ class GGUFMetadataKeys:
         BOS_ID = "tokenizer.bos_token_id"
         EOS_ID = "tokenizer.eos_token_id"
         UNK_ID = "tokenizer.unknown_token_id"
-        SEP_ID = "tokenizer.seperator_token_id"
+        SEP_ID = "tokenizer.separator_token_id"  # Fixed typo
         PAD_ID = "tokenizer.padding_token_id"
         CLS_ID = "tokenizer.cls_token_id"
         MASK_ID = "tokenizer.mask_token_id"
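Note that metadata written before this fix still carries the misspelled key, so read paths may want a fallback. A minimal sketch (the metadata dict is illustrative, not part of this commit):

    # Prefer the corrected key, fall back to the legacy misspelling
    sep_id = metadata.get(
        "tokenizer.separator_token_id",
        metadata.get("tokenizer.seperator_token_id"),  # legacy typo
    )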
@@ -1038,6 +1039,19 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
 }


+#
+# Model File Types
+#
+class ModelFileExtension(Enum):
+    PT = ".pt"  # torch
+    PTH = ".pth"  # torch
+    BIN = ".bin"  # torch
+    SAFETENSORS = ".safetensors"  # safetensors
+    JSON = ".json"  # transformers/tokenizers
+    MODEL = ".model"  # sentencepiece
+    GGUF = ".gguf"  # ggml/llama.cpp
+
+
 #
 # Tokenizer Types
 #
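The enum values are plain filename suffixes, so they compare directly against pathlib suffixes. A short usage sketch (the filename list is illustrative):

    import pathlib

    files = ["model-00001-of-00002.safetensors", "config.json", "README.md"]
    tensor_parts = [
        f for f in files
        if pathlib.Path(f).suffix == ModelFileExtension.SAFETENSORS.value
    ]
    # -> ["model-00001-of-00002.safetensors"]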
@@ -1050,29 +1064,16 @@ class GGUFTokenType(IntEnum):
     BYTE = 6


-class GGUFTokenizerType(Enum):
+class HFTokenizerType(Enum):
     SPM = "SPM"  # SentencePiece LLaMa tokenizer
     BPE = "BPE"  # BytePair GPT-2 tokenizer
     WPM = "WPM"  # WordPiece BERT tokenizer


-#
-# Model File Types
-#
-class GGUFFileExtension(Enum):
-    PT = ".pt"  # torch
-    PTH = ".pth"  # torch
-    BIN = ".bin"  # torch
-    SAFETENSORS = ".safetensors"  # safetensors
-    JSON = ".json"  # transformers/tokenizers
-    MODEL = ".model"  # sentencepiece
-    GGUF = ".gguf"  # ggml/llama.cpp
-
-
 #
 # Normalizer Types
 #
-class GGUFNormalizerType(Enum):
+class HFNormalizerType(Enum):
     SEQUENCE = "Sequence"
     NFC = "NFC"
     NFD = "NFD"
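The GGUF-to-HF rename signals that these enums describe upstream Hugging Face tokenizer families rather than GGUF file structures. One plausible detection sketch (the mapping from tokenizer.json "model"."type" values is an assumption, not something this commit defines):

    import json

    # Hypothetical helper: infer the family from a local tokenizer.json
    def detect_tokenizer_type(tokenizer_json_path: str) -> HFTokenizerType:
        with open(tokenizer_json_path, encoding="utf-8") as f:
            model_type = json.load(f)["model"]["type"]
        return {
            "BPE": HFTokenizerType.BPE,
            "WordPiece": HFTokenizerType.WPM,
            "Unigram": HFTokenizerType.SPM,
        }[model_type]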
@@ -1083,7 +1084,7 @@ class GGUFNormalizerType(Enum):
 #
 # Pre-tokenizer Types
 #
-class GGUFPreTokenizerType(Enum):
+class HFPreTokenizerType(Enum):
     WHITESPACE = "Whitespace"
     METASPACE = "Metaspace"
     BYTE_LEVEL = "ByteLevel"
@@ -1094,7 +1095,12 @@ class GGUFPreTokenizerType(Enum):
 #
 # HF Vocab Files
 #
-HF_TOKENIZER_BPE_FILES: tuple[str, ...] = ("config.json", "tokenizer_config.json", "tokenizer.json",)
+HF_TOKENIZER_BPE_FILES = (
+    "config.json",
+    "tokenizer_config.json",
+    "tokenizer.json",
+)

 HF_TOKENIZER_SPM_FILES: tuple[str, ...] = HF_TOKENIZER_BPE_FILES + ("tokenizer.model",)

 #
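Because HF_TOKENIZER_SPM_FILES extends the BPE tuple with "tokenizer.model", the SPM list is a strict superset and can serve as the single filter for both tokenizer flavors. A quick sanity sketch:

    assert set(HF_TOKENIZER_BPE_FILES) < set(HF_TOKENIZER_SPM_FILES)
    assert "tokenizer.model" in HF_TOKENIZER_SPM_FILES
    assert "tokenizer.model" not in HF_TOKENIZER_BPE_FILES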
@@ -1123,6 +1129,7 @@ KEY_GENERAL_LICENSE = GGUFMetadataKeys.General.LICENSE
 KEY_GENERAL_SOURCE_URL = GGUFMetadataKeys.General.SOURCE_URL
 KEY_GENERAL_SOURCE_REPO = GGUFMetadataKeys.General.SOURCE_REPO
 KEY_GENERAL_FILE_TYPE = GGUFMetadataKeys.General.FILE_TYPE
+KEY_GENERAL_ENDIANESS = GGUFMetadataKeys.General.ENDIANESS

 # LLM
 KEY_VOCAB_SIZE = GGUFMetadataKeys.LLM.VOCAB_SIZE
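The flat KEY_* aliases mirror the nested class attributes for callers that prefer bare names; the new endianness alias resolves to the same string as the nested constant:

    assert KEY_GENERAL_ENDIANESS == GGUFMetadataKeys.General.ENDIANESS == "general.endianess"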
@@ -3,24 +3,20 @@ import logging
 import os
 import pathlib
 from hashlib import sha256
+from typing import Protocol

 import requests
-from huggingface_hub import login, model_info
 from sentencepiece import SentencePieceProcessor
+from tqdm import tqdm

-from .constants import (
-    MODEL_TOKENIZER_BPE_FILES,
-    MODEL_TOKENIZER_SPM_FILES,
-    ModelFileExtension,
-    ModelNormalizerType,
-    ModelPreTokenizerType,
-    ModelTokenizerType,
-)
+from .constants import HF_TOKENIZER_SPM_FILES


-class HFHubBase:
+class HFHubBase(Protocol):
     def __init__(
-        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
+        self,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger,
     ):
         # Set the model path
         if model_path is None:
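With HFHubBase subclassing typing.Protocol, type checkers match implementations structurally (by attribute and method shape) rather than by inheritance; note that a Protocol class itself cannot be instantiated at runtime, only concrete subclasses such as HFHubRequest can. A minimal sketch of the idea, independent of this module:

    from typing import Protocol

    class HasWriteFile(Protocol):
        def write_file(self, content: bytes, file_path) -> None: ...

    def save(sink: HasWriteFile, data: bytes, path) -> None:
        # Any object with a matching write_file signature type-checks here
        sink.write_file(data, path)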
@@ -43,7 +39,7 @@ class HFHubBase:
     def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
         with open(file_path, "wb") as file:
             file.write(content)
-        self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
+        self.logger.debug(f"Wrote {len(content)} bytes to {file_path} successfully")


 class HFHubRequest(HFHubBase):
@@ -59,6 +55,11 @@ class HFHubRequest(HFHubBase):
         if auth_token is None:
             self._headers = None
         else:
+            # headers = {
+            #     "Authorization": f"Bearer {auth_token}",
+            #     "securityStatus": True,
+            #     "blobs": True,
+            # }
             self._headers = {"Authorization": f"Bearer {auth_token}"}

         # Persist across requests
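The commented-out dictionary preserves an earlier experiment with extra fields; only the bearer token is actually sent. A sketch of how the header flows into a session request (token and URL are placeholders):

    import requests

    session = requests.Session()
    headers = {"Authorization": "Bearer hf_..."}  # placeholder token
    response = session.get(
        "https://huggingface.co/api/models/some-org/some-model",
        headers=headers,
    )
    response.raise_for_status()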
@@ -67,11 +68,12 @@ class HFHubRequest(HFHubBase):
         # This is read-only
         self._base_url = "https://huggingface.co"

-        # NOTE: Required for getting model_info
-        login(auth_token, add_to_git_credential=True)
+        # NOTE: Cache repeat calls
+        self._model_repo = None
+        self._model_files = None

     @property
-    def headers(self) -> str:
+    def headers(self) -> None | dict[str, str]:
         return self._headers

     @property
@@ -82,34 +84,79 @@ class HFHubRequest(HFHubBase):
     def base_url(self) -> str:
         return self._base_url

-    @staticmethod
-    def list_remote_files(model_repo: str) -> list[str]:
-        # NOTE: Request repository metadata to extract remote filenames
-        return [x.rfilename for x in model_info(model_repo).siblings]
-
-    def list_filtered_remote_files(
-        self, model_repo: str, file_extension: ModelFileExtension
-    ) -> list[str]:
-        model_files = []
-        self.logger.info(f"Repo:{model_repo}")
-        self.logger.debug(f"FileExtension:{file_extension.value}")
-        for filename in HFHubRequest.list_remote_files(model_repo):
-            suffix = pathlib.Path(filename).suffix
-            self.logger.debug(f"Suffix: {suffix}")
-            if suffix == file_extension.value:
-                self.logger.info(f"File: {filename}")
-                model_files.append(filename)
-        return model_files
-
     def resolve_url(self, repo: str, filename: str) -> str:
         return f"{self._base_url}/{repo}/resolve/main/{filename}"

     def get_response(self, url: str) -> requests.Response:
+        # TODO: Stream requests and use tqdm to output the progress live
         response = self._session.get(url, headers=self.headers)
-        self.logger.info(f"Response status was {response.status_code}")
+        self.logger.debug(f"Response status was {response.status_code}")
         response.raise_for_status()
         return response

+    def model_info(self, model_repo: str) -> dict[str, object]:
+        url = f"{self._base_url}/api/models/{model_repo}"
+        return self.get_response(url).json()
+
+    def list_remote_files(self, model_repo: str) -> list[str]:
+        # NOTE: Reset the cache if the repo changed
+        if self._model_repo != model_repo:
+            self._model_repo = model_repo
+            self._model_files = []
+            for f in self.model_info(self._model_repo)["siblings"]:
+                self._model_files.append(f["rfilename"])
+            dump = json.dumps(self._model_files, indent=4)
+            self.logger.debug(f"Cached remote files: {dump}")
+        # Return the cached file listing
+        return self._model_files
+
+    def list_filtered_remote_files(
+        self, model_repo: str, file_suffix: str
+    ) -> list[str]:
+        model_files = []
+        self.logger.debug(f"Model Repo:{model_repo}")
+        self.logger.debug(f"File Suffix:{file_suffix}")
+        # NOTE: Valuable files are typically in the root path
+        for filename in self.list_remote_files(model_repo):
+            path = pathlib.Path(filename)
+            if len(path.parents) > 1:
+                continue  # skip nested paths
+            self.logger.debug(f"Path Suffix: {path.suffix}")
+            if path.suffix == file_suffix:
+                self.logger.debug(f"File Name: {filename}")
+                model_files.append(filename)
+        return model_files
+
+    def list_remote_safetensors(self, model_repo: str) -> list[str]:
+        # NOTE: HuggingFace recommends using safetensors to mitigate pickled injections
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".safetensors")
+            if part.startswith("model")
+        ]
+
+    def list_remote_bin(self, model_repo: str) -> list[str]:
+        # NOTE: HuggingFace is streamlining PyTorch models with the ".bin" extension
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".bin")
+            if part.startswith("pytorch_model")
+        ]
+
+    def list_remote_weights(self, model_repo: str) -> list[str]:
+        model_parts = self.list_remote_safetensors(model_repo)
+        if not model_parts:
+            model_parts = self.list_remote_bin(model_repo)
+        self.logger.debug(f"Remote model parts: {model_parts}")
+        return model_parts
+
+    def list_remote_tokenizers(self, model_repo: str) -> list[str]:
+        return [
+            tok
+            for tok in self.list_remote_files(model_repo)
+            if tok in HF_TOKENIZER_SPM_FILES
+        ]
+

 class HFHubTokenizer(HFHubBase):
     def __init__(
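Dropping huggingface_hub in favor of the plain /api/models endpoint removes the login() side effect and lets the class cache listings per repository. A usage sketch, assuming a constructor along the lines of HFHubRequest(auth_token, model_path, logger), which the diff only partially shows:

    hub = HFHubRequest(None, "models", None)
    repo = "some-org/some-model"  # placeholder repo
    weights = hub.list_remote_weights(repo)       # prefers *.safetensors over *.bin
    tokenizers = hub.list_remote_tokenizers(repo)
    files = hub.list_remote_files(repo)           # served from the cache this time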
@@ -118,11 +165,8 @@ class HFHubTokenizer(HFHubBase):
         super().__init__(model_path, logger)

     @staticmethod
-    def list_vocab_files(vocab_type: ModelTokenizerType) -> tuple[str, ...]:
-        if vocab_type == ModelTokenizerType.SPM.value:
-            return MODEL_TOKENIZER_SPM_FILES
-        # NOTE: WPM and BPE are equivalent
-        return MODEL_TOKENIZER_BPE_FILES
+    def list_vocab_files() -> tuple[str, ...]:
+        return HF_TOKENIZER_SPM_FILES

     def model(self, model_repo: str) -> SentencePieceProcessor:
         path = self.model_path / model_repo / "tokenizer.model"
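With the vocab_type parameter removed, callers always receive the SPM superset and can simply ignore files a given model does not ship:

    vocab_files = HFHubTokenizer.list_vocab_files()
    # ("config.json", "tokenizer_config.json", "tokenizer.json", "tokenizer.model")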
@@ -216,58 +260,62 @@ class HFHubModel(HFHubBase):

     def _request_single_file(
         self, model_repo: str, file_name: str, file_path: pathlib.Path
-    ) -> bool:
-        # NOTE: Consider optional `force` parameter if files need to be updated.
-        # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
-        if file_path.exists():
-            self.logger.info(f"skipped - downloaded {file_path} exists already.")
-            return False
-
+    ) -> None:
         # NOTE: Do not use bare exceptions! They mask issues!
         # Allow the exception to occur or explicitly handle it.
         try:
-            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
             resolved_url = self.request.resolve_url(model_repo, file_name)
             response = self.request.get_response(resolved_url)
             self.write_file(response.content, file_path)
-            self.logger.info(f"Model file successfully saved to {file_path}")
-            return True
         except requests.exceptions.HTTPError as e:
-            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
-            return False
+            self.logger.debug(f"Error while downloading '{file_name}': {str(e)}")

-    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
-        for file_name in remote_files:
+    def _request_listed_files(
+        self, model_repo: str, remote_files: list[str]
+    ) -> None:
+        for file_name in tqdm(remote_files, total=len(remote_files)):
             dir_path = self.model_path / model_repo
             os.makedirs(dir_path, exist_ok=True)
-            self._request_single_file(model_repo, file_name, dir_path / file_name)
+
+            # NOTE: Consider optional `force` parameter if files need to be updated.
+            # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
+            file_path = dir_path / file_name
+            if file_path.exists():
+                self.logger.debug(f"skipped - downloaded {file_path} exists already.")
+                continue  # skip existing files
+
+            self.logger.debug(f"Downloading '{file_name}' from {model_repo}")
+            self._request_single_file(model_repo, file_name, file_path)
+            self.logger.debug(f"Model file successfully saved to {file_path}")

     def config(self, model_repo: str) -> dict[str, object]:
         path = self.model_path / model_repo / "config.json"
         return json.loads(path.read_text(encoding="utf-8"))

     def architecture(self, model_repo: str) -> str:
-        config = self.config(model_repo)
         # NOTE: Allow IndexError to be raised because something unexpected happened.
         # The general assumption is there is only a single architecture, but
         # merged models may have multiple architecture types. This means this method
         # call is not guaranteed.
-        return config.get("architectures", [])[0]
+        try:
+            return self.config(model_repo).get("architectures", [])[0]
+        except IndexError:
+            self.logger.debug(f"Failed to get {model_repo} architecture")
+            return str()

-    def download_model_files(
-        self, model_repo: str, file_extension: ModelFileExtension
-    ) -> None:
-        filtered_files = self.request.list_filtered_remote_files(
-            model_repo, file_extension
-        )
-        self._request_listed_files(model_repo, filtered_files)
+    def download_model_weights(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_weights(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_vocab_files(
-        self, model_repo: str, vocab_type: ModelTokenizerType
-    ) -> None:
-        vocab_files = self.tokenizer.list_vocab_files(vocab_type)
-        self._request_listed_files(model_repo, vocab_files)
+    def download_model_tokenizers(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_tokenizers(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_model_files(self, model_repo: str) -> None:
+    def download_model_weights_and_tokenizers(self, model_repo: str) -> None:
+        # attempt by priority
+        self.download_model_weights(model_repo)
+        self.download_model_tokenizers(model_repo)
+
+    def download_all_repository_files(self, model_repo: str) -> None:
         all_files = self.request.list_remote_files(model_repo)
         self._request_listed_files(model_repo, all_files)
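The public download surface now encodes priority: safetensors parts first, then pytorch_model .bin parts, then tokenizer files. A sketch of the resulting call flow (the constructor shape and repo name are assumptions):

    model = HFHubModel(None, "models", None)
    repo = "some-org/some-model"  # placeholder repo
    model.download_model_weights_and_tokenizers(repo)
    print(model.architecture(repo))  # e.g. "LlamaForCausalLM", or "" on failure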