refactor: Clean up constants and simplify the custom hf hub api
commit 5836d6c7e7
parent ce8524afd0
2 changed files with 161 additions and 106 deletions
@@ -29,6 +29,7 @@ class GGUFMetadataKeys:
         SOURCE_URL = "general.source.url"
+        SOURCE_REPO = "general.source.repository"
         FILE_TYPE = "general.file_type"
         ENDIANESS = "general.endianess"

     class LLM:
         VOCAB_SIZE = "{arch}.vocab_size"
@@ -90,7 +91,7 @@ class GGUFMetadataKeys:
         BOS_ID = "tokenizer.bos_token_id"
         EOS_ID = "tokenizer.eos_token_id"
         UNK_ID = "tokenizer.unknown_token_id"
-        SEP_ID = "tokenizer.seperator_token_id"
+        SEP_ID = "tokenizer.separator_token_id"  # Fixed typo
         PAD_ID = "tokenizer.padding_token_id"
         CLS_ID = "tokenizer.cls_token_id"
         MASK_ID = "tokenizer.mask_token_id"
@@ -1038,6 +1039,19 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
 }


+#
+# Model File Types
+#
+class ModelFileExtension(Enum):
+    PT = ".pt"  # torch
+    PTH = ".pth"  # torch
+    BIN = ".bin"  # torch
+    SAFETENSORS = ".safetensors"  # safetensors
+    JSON = ".json"  # transformers/tokenizers
+    MODEL = ".model"  # sentencepiece
+    GGUF = ".gguf"  # ggml/llama.cpp
+
+
 #
 # Tokenizer Types
 #
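For context, the new ModelFileExtension enum is consumed by suffix comparisons against pathlib.Path(...).suffix. A minimal sketch of that pattern, using an illustrative two-member subset and made-up filenames:

import pathlib
from enum import Enum

# Illustrative two-member subset of the ModelFileExtension enum added above.
class ModelFileExtension(Enum):
    BIN = ".bin"  # torch
    SAFETENSORS = ".safetensors"  # safetensors

# Suffix matching, as the hub API below does with pathlib.
files = ["model-00001-of-00002.safetensors", "README.md", "pytorch_model.bin"]
matches = [
    f for f in files
    if pathlib.Path(f).suffix == ModelFileExtension.SAFETENSORS.value
]
print(matches)  # ['model-00001-of-00002.safetensors']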
@@ -1050,29 +1064,16 @@ class GGUFTokenType(IntEnum):
     BYTE = 6


-class GGUFTokenizerType(Enum):
+class HFTokenizerType(Enum):
     SPM = "SPM"  # SentencePiece LLaMa tokenizer
     BPE = "BPE"  # BytePair GPT-2 tokenizer
     WPM = "WPM"  # WordPiece BERT tokenizer


-#
-# Model File Types
-#
-class GGUFFileExtension(Enum):
-    PT = ".pt"  # torch
-    PTH = ".pth"  # torch
-    BIN = ".bin"  # torch
-    SAFETENSORS = ".safetensors"  # safetensors
-    JSON = ".json"  # transformers/tokenizers
-    MODEL = ".model"  # sentencepiece
-    GGUF = ".gguf"  # ggml/llama.cpp
-
-
 #
 # Normalizer Types
 #
-class GGUFNormalizerType(Enum):
+class HFNormalizerType(Enum):
     SEQUENCE = "Sequence"
     NFC = "NFC"
     NFD = "NFD"
@@ -1083,7 +1084,7 @@ class GGUFNormalizerType(Enum):
 #
 # Pre-tokenizer Types
 #
-class GGUFPreTokenizerType(Enum):
+class HFPreTokenizerType(Enum):
     WHITESPACE = "Whitespace"
     METASPACE = "Metaspace"
     BYTE_LEVEL = "ByteLevel"
@@ -1094,7 +1095,12 @@ class GGUFPreTokenizerType(Enum):
 #
 # HF Vocab Files
 #
-HF_TOKENIZER_BPE_FILES: tuple[str, ...] = ("config.json", "tokenizer_config.json", "tokenizer.json",)
+HF_TOKENIZER_BPE_FILES = (
+    "config.json",
+    "tokenizer_config.json",
+    "tokenizer.json",
+)

 HF_TOKENIZER_SPM_FILES: tuple[str, ...] = HF_TOKENIZER_BPE_FILES + ("tokenizer.model",)

 #
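A quick check of how the two tuples relate (values exactly as defined in this hunk): the SPM list extends the BPE list with the SentencePiece model file.

HF_TOKENIZER_BPE_FILES = (
    "config.json",
    "tokenizer_config.json",
    "tokenizer.json",
)
HF_TOKENIZER_SPM_FILES = HF_TOKENIZER_BPE_FILES + ("tokenizer.model",)

# The SPM file list is a strict superset of the BPE file list.
assert set(HF_TOKENIZER_BPE_FILES) < set(HF_TOKENIZER_SPM_FILES)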
@@ -1123,6 +1129,7 @@ KEY_GENERAL_LICENSE = GGUFMetadataKeys.General.LICENSE
 KEY_GENERAL_SOURCE_URL = GGUFMetadataKeys.General.SOURCE_URL
+KEY_GENERAL_SOURCE_REPO = GGUFMetadataKeys.General.SOURCE_REPO
 KEY_GENERAL_FILE_TYPE = GGUFMetadataKeys.General.FILE_TYPE
 KEY_GENERAL_ENDIANESS = GGUFMetadataKeys.General.ENDIANESS

 # LLM
 KEY_VOCAB_SIZE = GGUFMetadataKeys.LLM.VOCAB_SIZE
@@ -3,24 +3,20 @@ import logging
 import os
 import pathlib
 from hashlib import sha256
 from typing import Protocol

 import requests
-from huggingface_hub import login, model_info
 from sentencepiece import SentencePieceProcessor
 from tqdm import tqdm

-from .constants import (
-    MODEL_TOKENIZER_BPE_FILES,
-    MODEL_TOKENIZER_SPM_FILES,
-    ModelFileExtension,
-    ModelNormalizerType,
-    ModelPreTokenizerType,
-    ModelTokenizerType,
-)
+from .constants import HF_TOKENIZER_SPM_FILES


-class HFHubBase:
+class HFHubBase(Protocol):
     def __init__(
-        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
+        self,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger,
     ):
         # Set the model path
         if model_path is None:
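A minimal sketch of how the None | str | pathlib.Path union in the new signature can be normalized; the "models" fallback is a placeholder, since the actual default lies outside this hunk:

import pathlib

# Sketch: normalizing the union type from the signature above. The "models"
# default is a placeholder; the real fallback is outside this hunk.
def normalize_model_path(model_path: None | str | pathlib.Path) -> pathlib.Path:
    if model_path is None:
        return pathlib.Path("models")  # placeholder default
    return pathlib.Path(model_path)  # accepts both str and Path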
@@ -43,7 +39,7 @@ class HFHubBase:
     def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
         with open(file_path, "wb") as file:
             file.write(content)
-        self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
+        self.logger.debug(f"Wrote {len(content)} bytes to {file_path} successfully")


 class HFHubRequest(HFHubBase):
@@ -59,6 +55,11 @@ class HFHubRequest(HFHubBase):
         if auth_token is None:
             self._headers = None
         else:
+            # headers = {
+            #     "Authorization": f"Bearer {auth_token}",
+            #     "securityStatus": True,
+            #     "blobs": True,
+            # }
             self._headers = {"Authorization": f"Bearer {auth_token}"}

         # Persist across requests
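A sketch of the simplified header handling, assuming auth_token is a HuggingFace access token string or None; requests accepts headers=None, so anonymous access needs no special case:

import requests

# Sketch: the simplified header handling above, in isolation.
def make_headers(auth_token: None | str) -> None | dict[str, str]:
    if auth_token is None:
        return None  # anonymous access
    return {"Authorization": f"Bearer {auth_token}"}

session = requests.Session()
response = session.get(
    "https://huggingface.co/api/models/openai-community/gpt2",
    headers=make_headers(None),
)
print(response.status_code)  # 200 for a public repository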
@@ -67,11 +68,12 @@ class HFHubRequest(HFHubBase):
         # This is read-only
         self._base_url = "https://huggingface.co"

-        # NOTE: Required for getting model_info
-        login(auth_token, add_to_git_credential=True)
+        # NOTE: Cache repeat calls
+        self._model_repo = None
+        self._model_files = None

     @property
-    def headers(self) -> str:
+    def headers(self) -> None | dict[str, str]:
         return self._headers

     @property
@@ -82,34 +84,79 @@ class HFHubRequest(HFHubBase):
     def base_url(self) -> str:
         return self._base_url

-    @staticmethod
-    def list_remote_files(model_repo: str) -> list[str]:
-        # NOTE: Request repository metadata to extract remote filenames
-        return [x.rfilename for x in model_info(model_repo).siblings]
-
-    def list_filtered_remote_files(
-        self, model_repo: str, file_extension: ModelFileExtension
-    ) -> list[str]:
-        model_files = []
-        self.logger.info(f"Repo:{model_repo}")
-        self.logger.debug(f"FileExtension:{file_extension.value}")
-        for filename in HFHubRequest.list_remote_files(model_repo):
-            suffix = pathlib.Path(filename).suffix
-            self.logger.debug(f"Suffix: {suffix}")
-            if suffix == file_extension.value:
-                self.logger.info(f"File: {filename}")
-                model_files.append(filename)
-        return model_files
-
     def resolve_url(self, repo: str, filename: str) -> str:
         return f"{self._base_url}/{repo}/resolve/main/{filename}"

     def get_response(self, url: str) -> requests.Response:
         # TODO: Stream requests and use tqdm to output the progress live
         response = self._session.get(url, headers=self.headers)
-        self.logger.info(f"Response status was {response.status_code}")
+        self.logger.debug(f"Response status was {response.status_code}")
         response.raise_for_status()
         return response

+    def model_info(self, model_repo: str) -> dict[str, object]:
+        url = f"{self._base_url}/api/models/{model_repo}"
+        return self.get_response(url).json()
+
+    def list_remote_files(self, model_repo: str) -> list[str]:
+        # NOTE: Reset the cache if the repo changed
+        if self._model_repo != model_repo:
+            self._model_repo = model_repo
+            self._model_files = []
+            for f in self.model_info(self._model_repo)["siblings"]:
+                self._model_files.append(f["rfilename"])
+            dump = json.dumps(self._model_files, indent=4)
+            self.logger.debug(f"Cached remote files: {dump}")
+        # Return the cached file listing
+        return self._model_files
+
+    def list_filtered_remote_files(
+        self, model_repo: str, file_suffix: str
+    ) -> list[str]:
+        model_files = []
+        self.logger.debug(f"Model Repo:{model_repo}")
+        self.logger.debug(f"File Suffix:{file_suffix}")
+        # NOTE: Valuable files are typically in the root path
+        for filename in self.list_remote_files(model_repo):
+            path = pathlib.Path(filename)
+            if len(path.parents) > 1:
+                continue  # skip nested paths
+            self.logger.debug(f"Path Suffix: {path.suffix}")
+            if path.suffix == file_suffix:
+                self.logger.debug(f"File Name: {filename}")
+                model_files.append(filename)
+        return model_files
+
+    def list_remote_safetensors(self, model_repo: str) -> list[str]:
+        # NOTE: HuggingFace recommends using safetensors to mitigate pickled injections
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".safetensors")
+            if part.startswith("model")
+        ]
+
+    def list_remote_bin(self, model_repo: str) -> list[str]:
+        # NOTE: HuggingFace is streamlining PyTorch models with the ".bin" extension
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".bin")
+            if part.startswith("pytorch_model")
+        ]
+
+    def list_remote_weights(self, model_repo: str) -> list[str]:
+        model_parts = self.list_remote_safetensors(model_repo)
+        if not model_parts:
+            model_parts = self.list_remote_bin(model_repo)
+        self.logger.debug(f"Remote model parts: {model_parts}")
+        return model_parts
+
+    def list_remote_tokenizers(self, model_repo: str) -> list[str]:
+        return [
+            tok
+            for tok in self.list_remote_files(model_repo)
+            if tok in HF_TOKENIZER_SPM_FILES
+        ]
+

 class HFHubTokenizer(HFHubBase):
     def __init__(
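The caching scheme introduced in list_remote_files can be reproduced in isolation. This standalone sketch assumes anonymous access and relies on the same /api/models endpoint and its "siblings"/"rfilename" fields used above:

import requests

# Standalone sketch of the caching scheme in list_remote_files above:
# one metadata request per repository, reused until the repository changes.
class RemoteFileCache:
    def __init__(self) -> None:
        self._session = requests.Session()
        self._model_repo: None | str = None
        self._model_files: list[str] = []

    def list_remote_files(self, model_repo: str) -> list[str]:
        if self._model_repo != model_repo:
            url = f"https://huggingface.co/api/models/{model_repo}"
            response = self._session.get(url)
            response.raise_for_status()
            self._model_repo = model_repo
            self._model_files = [f["rfilename"] for f in response.json()["siblings"]]
        return self._model_files

cache = RemoteFileCache()
cache.list_remote_files("openai-community/gpt2")  # network request
cache.list_remote_files("openai-community/gpt2")  # served from the cache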
@@ -118,11 +165,8 @@ class HFHubTokenizer(HFHubBase):
         super().__init__(model_path, logger)

     @staticmethod
-    def list_vocab_files(vocab_type: ModelTokenizerType) -> tuple[str, ...]:
-        if vocab_type == ModelTokenizerType.SPM.value:
-            return MODEL_TOKENIZER_SPM_FILES
-        # NOTE: WPM and BPE are equivalent
-        return MODEL_TOKENIZER_BPE_FILES
+    def list_vocab_files() -> tuple[str, ...]:
+        return HF_TOKENIZER_SPM_FILES

     def model(self, model_repo: str) -> SentencePieceProcessor:
         path = self.model_path / model_repo / "tokenizer.model"
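A sketch of the load performed by HFHubTokenizer.model(); the repository path is hypothetical, following the model_path / model_repo / "tokenizer.model" layout in this hunk:

from sentencepiece import SentencePieceProcessor

# The local path below is hypothetical; Mistral-7B-v0.1 ships a
# SentencePiece "tokenizer.model", so it fits the SPM file list above.
processor = SentencePieceProcessor()
processor.Load("models/mistralai/Mistral-7B-v0.1/tokenizer.model")
print(processor.vocab_size())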
@@ -216,58 +260,62 @@ class HFHubModel(HFHubBase):

     def _request_single_file(
         self, model_repo: str, file_name: str, file_path: pathlib.Path
-    ) -> bool:
-        # NOTE: Consider optional `force` parameter if files need to be updated.
-        # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
-        if file_path.exists():
-            self.logger.info(f"skipped - downloaded {file_path} exists already.")
-            return False
-
+    ) -> None:
         # NOTE: Do not use bare exceptions! They mask issues!
         # Allow the exception to occur or explicitly handle it.
         try:
-            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
             resolved_url = self.request.resolve_url(model_repo, file_name)
             response = self.request.get_response(resolved_url)
             self.write_file(response.content, file_path)
-            self.logger.info(f"Model file successfully saved to {file_path}")
-            return True
         except requests.exceptions.HTTPError as e:
-            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
-            return False
+            self.logger.debug(f"Error while downloading '{file_name}': {str(e)}")

-    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
-        for file_name in remote_files:
+    def _request_listed_files(
+        self, model_repo: str, remote_files: list[str, ...]
+    ) -> None:
+        for file_name in tqdm(remote_files, total=len(remote_files)):
             dir_path = self.model_path / model_repo
             os.makedirs(dir_path, exist_ok=True)
-            self._request_single_file(model_repo, file_name, dir_path / file_name)
+
+            # NOTE: Consider optional `force` parameter if files need to be updated.
+            # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
+            file_path = dir_path / file_name
+            if file_path.exists():
+                self.logger.debug(f"skipped - downloaded {file_path} exists already.")
+                continue  # skip existing files
+
+            self.logger.debug(f"Downloading '{file_name}' from {model_repo}")
+            self._request_single_file(model_repo, file_name, file_path)
+            self.logger.debug(f"Model file successfully saved to {file_path}")

     def config(self, model_repo: str) -> dict[str, object]:
         path = self.model_path / model_repo / "config.json"
         return json.loads(path.read_text(encoding="utf-8"))

     def architecture(self, model_repo: str) -> str:
-        config = self.config(model_repo)
         # NOTE: Allow IndexError to be raised because something unexpected happened.
         # The general assumption is there is only a single architecture, but
         # merged models may have multiple architecture types. This means this method
         # call is not guaranteed.
-        return config.get("architectures", [])[0]
+        try:
+            return self.config(model_repo).get("architectures", [])[0]
+        except IndexError:
+            self.logger.debug(f"Failed to get {model_repo} architecture")
+            return str()

-    def download_model_files(
-        self, model_repo: str, file_extension: ModelFileExtension
-    ) -> None:
-        filtered_files = self.request.list_filtered_remote_files(
-            model_repo, file_extension
-        )
-        self._request_listed_files(model_repo, filtered_files)
+    def download_model_weights(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_weights(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_vocab_files(
-        self, model_repo: str, vocab_type: ModelTokenizerType
-    ) -> None:
-        vocab_files = self.tokenizer.list_vocab_files(vocab_type)
-        self._request_listed_files(model_repo, vocab_files)
+    def download_model_tokenizers(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_tokenizers(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_model_files(self, model_repo: str) -> None:
+    def download_model_weights_and_tokenizers(self, model_repo: str) -> None:
+        # attempt by priority
+        self.download_model_weights(model_repo)
+        self.download_model_tokenizers(model_repo)
+
+    def download_all_repository_files(self, model_repo: str) -> None:
         all_files = self.request.list_remote_files(model_repo)
         self._request_listed_files(model_repo, all_files)
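A hypothetical end-to-end use of the simplified API; only the method names and the weights-then-tokenizers priority come from this diff, while the constructor arguments are assumptions:

import logging
import pathlib

# Assumes HFHubModel is imported from the module this diff touches; the
# constructor arguments here are guesses, not part of the diff.
logging.basicConfig(level=logging.DEBUG)
model = HFHubModel(
    auth_token=None,
    model_path=pathlib.Path("models"),
    logger=logging.getLogger("hub"),
)
# Downloads safetensors weights if present, otherwise .bin parts, then the
# tokenizer files; files that already exist locally are skipped.
model.download_model_weights_and_tokenizers("openai-community/gpt2")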