Refactor HFHubModel and HFHubTokenizer to fix reference issues

teleprint-me 2024-05-25 04:15:15 -04:00
parent fda2319d7b
commit 2ffe6b89c8
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48
2 changed files with 79 additions and 56 deletions


@@ -114,56 +114,6 @@ class HFHubRequest(HFHubBase):
        return response


class HFHubModel(HFHubBase):
    def __init__(
        self,
        auth_token: None | str,
        model_path: None | str | pathlib.Path,
        logger: None | logging.Logger
    ):
        super().__init__(model_path, logger)
        self._request = HFHubRequest(auth_token, model_path, logger)

    @property
    def request(self) -> HFHubRequest:
        return self._request

    def _request_single_file(
        self, model_repo: str, file_name: str, file_path: pathlib.Path
    ) -> bool:
        # NOTE: Do not use bare exceptions! They mask issues!
        # Allow the exception to occur or explicitly handle it.
        try:
            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
            resolved_url = self.request.resolve_url(model_repo, file_name)
            response = self.request.get_response(resolved_url)
            self.write_file(response.content, file_path)
            self.logger.info(f"Model file successfully saved to {file_path}")
            return True
        except requests.exceptions.HTTPError as e:
            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
            return False

    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
        for file_name in remote_files:
            dir_path = self.model_path / model_repo
            os.makedirs(dir_path, exist_ok=True)
            self._request_single_file(model_repo, file_name, dir_path / file_name)

    def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
        filtered_files = self.list_filtered_remote_files(model_repo, file_type)
        self._request_listed_files(model_repo, filtered_files)

    def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
        vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
        self._request_listed_files(model_repo, vocab_files)

    def download_all_model_files(self, model_repo: str) -> None:
        all_files = self.list_remote_files(model_repo)
        self._request_listed_files(model_repo, all_files)


class HFHubTokenizer(HFHubBase):
    def __init__(
        self,
@@ -173,7 +123,7 @@ class HFHubTokenizer(HFHubBase):
        super().__init__(model_path, logger)

    @staticmethod
    def get_vocab_filenames(vocab_type: VocabType) -> tuple[str]:
    def list_vocab_files(vocab_type: VocabType) -> tuple[str]:
        if vocab_type == VocabType.SPM:
            return HF_TOKENIZER_SPM_FILES
        # NOTE: WPM and BPE are equivalent
@@ -184,7 +134,7 @@ class HFHubTokenizer(HFHubBase):
        return VOCAB_TYPE_NAMES.get(vocab_type)

    @staticmethod
    def get_vocab_enum(vocab_name: str) -> VocabType:
    def get_vocab_type(vocab_name: str) -> VocabType:
        return {
            "SPM": VocabType.SPM,
            "BPE": VocabType.BPE,
@@ -255,3 +205,58 @@ class HFHubTokenizer(HFHubBase):
            for x, y in v.items():
                if x not in ["vocab", "merges"]:
                    self.logger.info(f"{k}:{x}:{json.dumps(y, indent=2)}")


class HFHubModel(HFHubBase):
    def __init__(
        self,
        auth_token: None | str,
        model_path: None | str | pathlib.Path,
        logger: None | logging.Logger
    ):
        super().__init__(model_path, logger)
        self._request = HFHubRequest(auth_token, model_path, logger)
        self._tokenizer = HFHubTokenizer(model_path, logger)

    @property
    def request(self) -> HFHubRequest:
        return self._request

    @property
    def tokenizer(self) -> HFHubTokenizer:
        return self._tokenizer

    def _request_single_file(
        self, model_repo: str, file_name: str, file_path: pathlib.Path
    ) -> bool:
        # NOTE: Do not use bare exceptions! They mask issues!
        # Allow the exception to occur or explicitly handle it.
        try:
            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
            resolved_url = self.request.resolve_url(model_repo, file_name)
            response = self.request.get_response(resolved_url)
            self.write_file(response.content, file_path)
            self.logger.info(f"Model file successfully saved to {file_path}")
            return True
        except requests.exceptions.HTTPError as e:
            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
            return False

    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
        for file_name in remote_files:
            dir_path = self.model_path / model_repo
            os.makedirs(dir_path, exist_ok=True)
            self._request_single_file(model_repo, file_name, dir_path / file_name)

    def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
        filtered_files = self.request.list_filtered_remote_files(model_repo, file_type)
        self._request_listed_files(model_repo, filtered_files)

    def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
        vocab_files = self.tokenizer.list_vocab_files(vocab_type)
        self._request_listed_files(model_repo, vocab_files)

    def download_all_model_files(self, model_repo: str) -> None:
        all_files = self.request.list_remote_files(model_repo)
        self._request_listed_files(model_repo, all_files)
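For orientation, a minimal usage sketch of the relocated HFHubModel. The import paths, the enum's module, and the example repo id are assumptions, not part of this diff:

import logging
import pathlib

# Assumed import paths; this diff only shows the class bodies, not the module layout.
from gguf.huggingface_hub import HFHubModel, HFHubTokenizer, VocabType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hub-sketch")

# HFHubModel now composes both helpers, so callers reach them via properties.
model = HFHubModel(
    auth_token=None,                    # anonymous access to a public repo
    model_path=pathlib.Path("models"),  # local download root
    logger=logger,
)

vocab_type = HFHubTokenizer.get_vocab_type("BPE")  # renamed from get_vocab_enum
model.download_all_vocab_files("openai-community/gpt2", vocab_type)  # example repo id

Defining HFHubModel after HFHubTokenizer lets it instantiate the tokenizer directly, which is the reference issue the commit title describes.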


@@ -1,5 +1,21 @@
#!/usr/bin/env python3
"""
Tokenizers Vocabulary Notes:

Normalizers:
    TODO

Pre-tokenizers:
    Byte Level Pre-tokenization uses the openai/gpt-2 RegEx from `encoder.py` by default.
    There are other Pre-tokenization types, e.g. BERT, which inherits from Byte Level.
    The defaults for each RegEx are identical in either case.

    Pre-tokenization encompasses identifying characters and their types:
    - A pattern may match a type of "Sequence"
    - Letters and Numbers: Alphabetic or Alphanumeric
    - Whitespace:
"""
from __future__ import annotations

import argparse
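To make the docstring's byte-level note concrete, here is a small sketch of the GPT-2 pre-tokenization RegEx published in openai/gpt-2's encoder.py; it uses the third-party regex package for the \p{L} and \p{N} Unicode classes and is illustrative only, not part of this diff:

import regex  # pip install regex; stdlib `re` lacks the \p{...} Unicode classes

# The byte-level pre-tokenization pattern from openai/gpt-2 encoder.py.
GPT2_PRETOKENIZE = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

# Splits text into letter runs, digit runs, punctuation, and whitespace
# chunks before any BPE merges are applied.
print(GPT2_PRETOKENIZE.findall("Hello world, it's 2024!"))
# ['Hello', ' world', ',', ' it', "'s", ' 2024', '!']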
@@ -16,7 +32,7 @@ if (
    sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES
from gguf.huggingface_hub import HFVocabRequest
from gguf.huggingface_hub import HFHub, HFTokenizer
logger = logging.getLogger(Path(__file__).stem)
@@ -33,7 +49,7 @@ def main():
    parser.add_argument(
        "-m",
        "--model-path",
        default="models/",
        default="models",
        help="The models storage path. Default is 'models/'.",
    )
    parser.add_argument(
@@ -50,9 +66,11 @@ def main():
    else:
        logging.basicConfig(level=logging.INFO)

    vocab_request = HFVocabRequest(args.auth_token, args.model_path, logger)
    vocab_type = vocab_request.get_vocab_enum(args.vocab_type)
    vocab_request = HFModel(args.auth_token, args.model_path, logger)
    vocab_type = HFTokenizer.get_vocab_enum(args.vocab_type)
    tokenizer = vocab_request.tokenizer

    vocab_request.get_all_vocab_files(args.model_repo, vocab_type)
    tokenizer.log_tokenizer_json_info(args.model_repo)


if __name__ == "__main__":
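For completeness, a self-contained sketch of the name/enum mapping that get_vocab_type and get_vocab_name implement above; the enum members and the BPE fallback are assumptions, since VocabType itself is defined outside these hunks:

from enum import IntEnum, auto


class VocabType(IntEnum):
    # Assumed members; the real enum lives outside this diff.
    SPM = auto()  # SentencePiece
    BPE = auto()  # Byte-Pair Encoding
    WPM = auto()  # WordPiece (treated like BPE when listing vocab files)


VOCAB_TYPE_NAMES = {member: member.name for member in VocabType}


def get_vocab_type(vocab_name: str) -> VocabType:
    # Mirrors the dict lookup shown in HFHubTokenizer.get_vocab_type;
    # the BPE fallback for unknown names is illustrative only.
    return {
        "SPM": VocabType.SPM,
        "BPE": VocabType.BPE,
        "WPM": VocabType.WPM,
    }.get(vocab_name, VocabType.BPE)


print(get_vocab_type("SPM").name)        # SPM
print(VOCAB_TYPE_NAMES[VocabType.BPE])   # BPE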