refactor: Clean up constants and simplify the custom hf hub api

teleprint-me committed 2024-06-02 16:27:21 -04:00
commit 5836d6c7e7 (parent ce8524afd0)
GPG key ID: B0D11345E65C4D48 (no known key found for this signature in database)
2 changed files with 161 additions and 106 deletions

File 1 of 2 (constants module):

@@ -29,6 +29,7 @@ class GGUFMetadataKeys:
         SOURCE_URL = "general.source.url"
         SOURCE_REPO = "general.source.repository"
         FILE_TYPE = "general.file_type"
+        ENDIANESS = "general.endianess"

     class LLM:
         VOCAB_SIZE = "{arch}.vocab_size"
@@ -77,20 +78,20 @@ class GGUFMetadataKeys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"

     class Tokenizer:
         MODEL = "tokenizer.model"  # STRING: e.g. llama, gpt2, etc...
         TYPE = "tokenizer.type"  # STRING: BPE, SPM, WPM, etc.
         NORM = "tokenizer.norm"  # OBJECT {"type": "ByteLevel", ...}
         PRE = "tokenizer.pre"  # OBJECT {"type": "ByteLevel", ...}
         ADDED = "tokenizer.added"  # ARRAY of OBJECTs: [{"id": 1, ...}, ...]
         VOCAB = "tokenizer.vocab"  # ARRAY of STRINGs: ["[BOS]", ...]
         MERGES = "tokenizer.merges"  # ARRAY of STRINGs: ["▁ t", ...]
         TOKEN_TYPE = "tokenizer.token_type"  # ARRAY of INT [2, ...]
         TOKEN_TYPE_COUNT = "tokenizer.token_type_count"  # BERT token types
         SCORES = "tokenizer.scores"  # WPM only
         BOS_ID = "tokenizer.bos_token_id"
         EOS_ID = "tokenizer.eos_token_id"
         UNK_ID = "tokenizer.unknown_token_id"
-        SEP_ID = "tokenizer.seperator_token_id"
+        SEP_ID = "tokenizer.separator_token_id"  # Fixed typo
         PAD_ID = "tokenizer.padding_token_id"
         CLS_ID = "tokenizer.cls_token_id"
         MASK_ID = "tokenizer.mask_token_id"
@@ -1038,6 +1039,19 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
 }


+#
+# Model File Types
+#
+class ModelFileExtension(Enum):
+    PT = ".pt"  # torch
+    PTH = ".pth"  # torch
+    BIN = ".bin"  # torch
+    SAFETENSORS = ".safetensors"  # safetensors
+    JSON = ".json"  # transformers/tokenizers
+    MODEL = ".model"  # sentencepiece
+    GGUF = ".gguf"  # ggml/llama.cpp
+
+
 #
 # Tokenizer Types
 #
@@ -1050,51 +1064,43 @@ class GGUFTokenType(IntEnum):
     BYTE = 6


-class GGUFTokenizerType(Enum):
+class HFTokenizerType(Enum):
     SPM = "SPM"  # SentencePiece LLaMa tokenizer
     BPE = "BPE"  # BytePair GPT-2 tokenizer
     WPM = "WPM"  # WordPiece BERT tokenizer


-#
-# Model File Types
-#
-class GGUFFileExtension(Enum):
-    PT = ".pt"  # torch
-    PTH = ".pth"  # torch
-    BIN = ".bin"  # torch
-    SAFETENSORS = ".safetensors"  # safetensors
-    JSON = ".json"  # transformers/tokenizers
-    MODEL = ".model"  # sentencepiece
-    GGUF = ".gguf"  # ggml/llama.cpp
-
-
 #
 # Normalizer Types
 #
-class GGUFNormalizerType(Enum):
+class HFNormalizerType(Enum):
     SEQUENCE = "Sequence"
     NFC = "NFC"
     NFD = "NFD"
     NFKC = "NFKC"
     NFKD = "NFKD"


 #
 # Pre-tokenizer Types
 #
-class GGUFPreTokenizerType(Enum):
+class HFPreTokenizerType(Enum):
     WHITESPACE = "Whitespace"
     METASPACE = "Metaspace"
     BYTE_LEVEL = "ByteLevel"
     BERT_PRE_TOKENIZER = "BertPreTokenizer"
     SEQUENCE = "Sequence"


 #
 # HF Vocab Files
 #
-HF_TOKENIZER_BPE_FILES: tuple[str, ...] = ("config.json", "tokenizer_config.json", "tokenizer.json",)
+HF_TOKENIZER_BPE_FILES = (
+    "config.json",
+    "tokenizer_config.json",
+    "tokenizer.json",
+)
 HF_TOKENIZER_SPM_FILES: tuple[str, ...] = HF_TOKENIZER_BPE_FILES + ("tokenizer.model",)

 #
@@ -1123,6 +1129,7 @@ KEY_GENERAL_LICENSE = GGUFMetadataKeys.General.LICENSE
 KEY_GENERAL_SOURCE_URL = GGUFMetadataKeys.General.SOURCE_URL
 KEY_GENERAL_SOURCE_REPO = GGUFMetadataKeys.General.SOURCE_REPO
 KEY_GENERAL_FILE_TYPE = GGUFMetadataKeys.General.FILE_TYPE
+KEY_GENERAL_ENDIANESS = GGUFMetadataKeys.General.ENDIANESS

 # LLM
 KEY_VOCAB_SIZE = GGUFMetadataKeys.LLM.VOCAB_SIZE
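Reviewer note: a minimal sketch of how the renamed constants might be consumed by a caller. The import path `gguf.constants` and the helper below are assumptions for illustration, not part of this diff.

# Hypothetical usage sketch; `gguf.constants` as the import path is an
# assumption about the package layout, not something this diff shows.
import pathlib

from gguf.constants import (
    HF_TOKENIZER_SPM_FILES,
    HFTokenizerType,
    ModelFileExtension,
)

def is_known_model_file(filename: str) -> bool:
    # ModelFileExtension enumerates file suffixes, so matching is a suffix check
    suffixes = {ext.value for ext in ModelFileExtension}
    return pathlib.Path(filename).suffix in suffixes

assert is_known_model_file("model-00001-of-00002.safetensors")
assert "tokenizer.model" in HF_TOKENIZER_SPM_FILES
assert HFTokenizerType.SPM.value == "SPM"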

File 2 of 2 (custom HF hub API module):

@@ -3,24 +3,20 @@ import logging
 import os
 import pathlib
 from hashlib import sha256
+from typing import Protocol

 import requests
-from huggingface_hub import login, model_info
 from sentencepiece import SentencePieceProcessor
+from tqdm import tqdm

-from .constants import (
-    MODEL_TOKENIZER_BPE_FILES,
-    MODEL_TOKENIZER_SPM_FILES,
-    ModelFileExtension,
-    ModelNormalizerType,
-    ModelPreTokenizerType,
-    ModelTokenizerType,
-)
+from .constants import HF_TOKENIZER_SPM_FILES


-class HFHubBase:
+class HFHubBase(Protocol):
     def __init__(
-        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
+        self,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger,
     ):
         # Set the model path
         if model_path is None:
@@ -43,7 +39,7 @@ class HFHubBase:
     def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
         with open(file_path, "wb") as file:
             file.write(content)
-        self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
+        self.logger.debug(f"Wrote {len(content)} bytes to {file_path} successfully")


 class HFHubRequest(HFHubBase):
@@ -59,6 +55,11 @@ class HFHubRequest(HFHubBase):
         if auth_token is None:
             self._headers = None
         else:
+            # headers = {
+            #     "Authorization": f"Bearer {auth_token}",
+            #     "securityStatus": True,
+            #     "blobs": True,
+            # }
             self._headers = {"Authorization": f"Bearer {auth_token}"}

         # Persist across requests
@ -67,11 +68,12 @@ class HFHubRequest(HFHubBase):
# This is read-only # This is read-only
self._base_url = "https://huggingface.co" self._base_url = "https://huggingface.co"
# NOTE: Required for getting model_info # NOTE: Cache repeat calls
login(auth_token, add_to_git_credential=True) self._model_repo = None
self._model_files = None
@property @property
def headers(self) -> str: def headers(self) -> None | dict[str, str]:
return self._headers return self._headers
@property @property
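Reviewer note: the caching methods added in the next hunk consume the Hub's public model listing endpoint. For reference, a trimmed sketch of the JSON that GET {base_url}/api/models/{model_repo} returns; only the "siblings" field is used, and the repo and file names below are illustrative.

# Trimmed example of the Hub API response consumed by list_remote_files;
# the repo and file names are illustrative, not taken from this diff.
example_model_info = {
    "id": "openai-community/gpt2",
    "siblings": [
        {"rfilename": "config.json"},
        {"rfilename": "model.safetensors"},
        {"rfilename": "tokenizer.json"},
    ],
}
remote_files = [f["rfilename"] for f in example_model_info["siblings"]]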
@ -82,34 +84,79 @@ class HFHubRequest(HFHubBase):
def base_url(self) -> str: def base_url(self) -> str:
return self._base_url return self._base_url
@staticmethod
def list_remote_files(model_repo: str) -> list[str]:
# NOTE: Request repository metadata to extract remote filenames
return [x.rfilename for x in model_info(model_repo).siblings]
def list_filtered_remote_files(
self, model_repo: str, file_extension: ModelFileExtension
) -> list[str]:
model_files = []
self.logger.info(f"Repo:{model_repo}")
self.logger.debug(f"FileExtension:{file_extension.value}")
for filename in HFHubRequest.list_remote_files(model_repo):
suffix = pathlib.Path(filename).suffix
self.logger.debug(f"Suffix: {suffix}")
if suffix == file_extension.value:
self.logger.info(f"File: {filename}")
model_files.append(filename)
return model_files
def resolve_url(self, repo: str, filename: str) -> str: def resolve_url(self, repo: str, filename: str) -> str:
return f"{self._base_url}/{repo}/resolve/main/{filename}" return f"{self._base_url}/{repo}/resolve/main/{filename}"
def get_response(self, url: str) -> requests.Response: def get_response(self, url: str) -> requests.Response:
# TODO: Stream requests and use tqdm to output the progress live
response = self._session.get(url, headers=self.headers) response = self._session.get(url, headers=self.headers)
self.logger.info(f"Response status was {response.status_code}") self.logger.debug(f"Response status was {response.status_code}")
response.raise_for_status() response.raise_for_status()
return response return response
def model_info(self, model_repo: str) -> dict[str, object]:
url = f"{self._base_url}/api/models/{model_repo}"
return self.get_response(url).json()
def list_remote_files(self, model_repo: str) -> list[str]:
# NOTE: Reset the cache if the repo changed
if self._model_repo != model_repo:
self._model_repo = model_repo
self._model_files = []
for f in self.model_info(self._model_repo)["siblings"]:
self._model_files.append(f["rfilename"])
dump = json.dumps(self._model_files, indent=4)
self.logger.debug(f"Cached remote files: {dump}")
# Return the cached file listing
return self._model_files
def list_filtered_remote_files(
self, model_repo: str, file_suffix: str
) -> list[str]:
model_files = []
self.logger.debug(f"Model Repo:{model_repo}")
self.logger.debug(f"File Suffix:{file_suffix}")
# NOTE: Valuable files are typically in the root path
for filename in self.list_remote_files(model_repo):
path = pathlib.Path(filename)
if len(path.parents) > 1:
continue # skip nested paths
self.logger.debug(f"Path Suffix: {path.suffix}")
if path.suffix == file_suffix:
self.logger.debug(f"File Name: {filename}")
model_files.append(filename)
return model_files
def list_remote_safetensors(self, model_repo: str) -> list[str]:
# NOTE: HuggingFace recommends using safetensors to mitigate pickled injections
return [
part
for part in self.list_filtered_remote_files(model_repo, ".safetensors")
if part.startswith("model")
]
def list_remote_bin(self, model_repo: str) -> list[str]:
# NOTE: HuggingFace is streamlining PyTorch models with the ".bin" extension
return [
part
for part in self.list_filtered_remote_files(model_repo, ".bin")
if part.startswith("pytorch_model")
]
def list_remote_weights(self, model_repo: str) -> list[str]:
model_parts = self.list_remote_safetensors(model_repo)
if not model_parts:
model_parts = self.list_remote_bin(model_repo)
self.logger.debug(f"Remote model parts: {model_parts}")
return model_parts
def list_remote_tokenizers(self, model_repo: str) -> list[str]:
return [
tok
for tok in self.list_remote_files(model_repo)
if tok in HF_TOKENIZER_SPM_FILES
]
class HFHubTokenizer(HFHubBase): class HFHubTokenizer(HFHubBase):
def __init__( def __init__(
@@ -118,11 +165,8 @@ class HFHubTokenizer(HFHubBase):
         super().__init__(model_path, logger)

     @staticmethod
-    def list_vocab_files(vocab_type: ModelTokenizerType) -> tuple[str, ...]:
-        if vocab_type == ModelTokenizerType.SPM.value:
-            return MODEL_TOKENIZER_SPM_FILES
-        # NOTE: WPM and BPE are equivalent
-        return MODEL_TOKENIZER_BPE_FILES
+    def list_vocab_files() -> tuple[str, ...]:
+        return HF_TOKENIZER_SPM_FILES

     def model(self, model_repo: str) -> SentencePieceProcessor:
         path = self.model_path / model_repo / "tokenizer.model"
@@ -216,58 +260,62 @@ class HFHubModel(HFHubBase):
     def _request_single_file(
         self, model_repo: str, file_name: str, file_path: pathlib.Path
-    ) -> bool:
-        # NOTE: Consider optional `force` parameter if files need to be updated.
-        # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
-        if file_path.exists():
-            self.logger.info(f"skipped - downloaded {file_path} exists already.")
-            return False
-
+    ) -> None:
         # NOTE: Do not use bare exceptions! They mask issues!
         # Allow the exception to occur or explicitly handle it.
         try:
-            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
             resolved_url = self.request.resolve_url(model_repo, file_name)
             response = self.request.get_response(resolved_url)
             self.write_file(response.content, file_path)
-            self.logger.info(f"Model file successfully saved to {file_path}")
-            return True
         except requests.exceptions.HTTPError as e:
-            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
-            return False
+            self.logger.debug(f"Error while downloading '{file_name}': {str(e)}")

-    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
-        for file_name in remote_files:
+    def _request_listed_files(
+        self, model_repo: str, remote_files: list[str, ...]
+    ) -> None:
+        for file_name in tqdm(remote_files, total=len(remote_files)):
             dir_path = self.model_path / model_repo
             os.makedirs(dir_path, exist_ok=True)
-            self._request_single_file(model_repo, file_name, dir_path / file_name)
+
+            # NOTE: Consider optional `force` parameter if files need to be updated.
+            # e.g. The model creator updated the vocabulary to resolve an issue or add a feature.
+            file_path = dir_path / file_name
+            if file_path.exists():
+                self.logger.debug(f"skipped - downloaded {file_path} exists already.")
+                continue  # skip existing files
+
+            self.logger.debug(f"Downloading '{file_name}' from {model_repo}")
+            self._request_single_file(model_repo, file_name, file_path)
+            self.logger.debug(f"Model file successfully saved to {file_path}")

     def config(self, model_repo: str) -> dict[str, object]:
         path = self.model_path / model_repo / "config.json"
         return json.loads(path.read_text(encoding="utf-8"))

     def architecture(self, model_repo: str) -> str:
-        config = self.config(model_repo)
         # NOTE: Allow IndexError to be raised because something unexpected happened.
         # The general assumption is there is only a single architecture, but
         # merged models may have multiple architecture types. This means this method
         # call is not guaranteed.
-        return config.get("architectures", [])[0]
+        try:
+            return self.config(model_repo).get("architectures", [])[0]
+        except IndexError:
+            self.logger.debug(f"Failed to get {model_repo} architecture")
+            return str()

-    def download_model_files(
-        self, model_repo: str, file_extension: ModelFileExtension
-    ) -> None:
-        filtered_files = self.request.list_filtered_remote_files(
-            model_repo, file_extension
-        )
-        self._request_listed_files(model_repo, filtered_files)
+    def download_model_weights(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_weights(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_vocab_files(
-        self, model_repo: str, vocab_type: ModelTokenizerType
-    ) -> None:
-        vocab_files = self.tokenizer.list_vocab_files(vocab_type)
-        self._request_listed_files(model_repo, vocab_files)
+    def download_model_tokenizers(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_tokenizers(model_repo)
+        self._request_listed_files(model_repo, remote_files)

-    def download_all_model_files(self, model_repo: str) -> None:
+    def download_model_weights_and_tokenizers(self, model_repo: str) -> None:
+        # attempt by priority
+        self.download_model_weights(model_repo)
+        self.download_model_tokenizers(model_repo)
+
+    def download_all_repository_files(self, model_repo: str) -> None:
         all_files = self.request.list_remote_files(model_repo)
         self._request_listed_files(model_repo, all_files)
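Reviewer note: taken together, a minimal driver for the simplified API might look like the sketch below. HFHubModel's constructor is outside this diff, so the (auth_token, model_path, logger) arguments and the import path are assumptions; the repo name is illustrative.

# Hypothetical driver; the constructor arguments are assumptions since
# HFHubModel.__init__ is not shown in this diff.
import logging

from gguf.huggingface_hub import HFHubModel  # import path is an assumption

logger = logging.getLogger("hf-hub")
hub = HFHubModel(auth_token=None, model_path="models", logger=logger)

model_repo = "openai-community/gpt2"  # illustrative repo
hub.download_model_weights(model_repo)     # tries safetensors first, ".bin" as fallback
hub.download_model_tokenizers(model_repo)  # config.json, tokenizer.json, tokenizer.model, ...
print(hub.architecture(model_repo))        # e.g. "GPT2LMHeadModel"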