Refactor HFHubModel and HFHubTokenizer to fix reference issues
This commit is contained in:
parent fda2319d7b
commit 2ffe6b89c8

2 changed files with 79 additions and 56 deletions
@@ -114,56 +114,6 @@ class HFHubRequest(HFHubBase):
         return response
 
 
-class HFHubModel(HFHubBase):
-    def __init__(
-        self,
-        auth_token: None | str,
-        model_path: None | str | pathlib.Path,
-        logger: None | logging.Logger
-    ):
-        super().__init__(model_path, logger)
-
-        self._request = HFHubRequest(auth_token, model_path, logger)
-
-    @property
-    def request(self) -> HFHubRequest:
-        return self._request
-
-    def _request_single_file(
-        self, model_repo: str, file_name: str, file_path: pathlib.Path
-    ) -> bool:
-        # NOTE: Do not use bare exceptions! They mask issues!
-        # Allow the exception to occur or explicitly handle it.
-        try:
-            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
-            resolved_url = self.request.resolve_url(model_repo, file_name)
-            response = self.request.get_response(resolved_url)
-            self.write_file(response.content, file_path)
-            self.logger.info(f"Model file successfully saved to {file_path}")
-            return True
-        except requests.exceptions.HTTPError as e:
-            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
-            return False
-
-    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
-        for file_name in remote_files:
-            dir_path = self.model_path / model_repo
-            os.makedirs(dir_path, exist_ok=True)
-            self._request_single_file(model_repo, file_name, dir_path / file_name)
-
-    def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
-        filtered_files = self.list_filtered_remote_files(model_repo, file_type)
-        self._request_listed_files(model_repo, filtered_files)
-
-    def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
-        vocab_files = self.tokenizer.get_vocab_filenames(vocab_type)
-        self._request_listed_files(model_repo, vocab_files)
-
-    def download_all_model_files(self, model_repo: str) -> None:
-        all_files = self.list_remote_files(model_repo)
-        self._request_listed_files(model_repo, all_files)
-
-
 class HFHubTokenizer(HFHubBase):
     def __init__(
         self,
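An aside on the NOTE in `_request_single_file`: catching `requests.exceptions.HTTPError` only works because `requests` raises it on demand, so `get_response` presumably calls `raise_for_status()` somewhere. A minimal sketch of that contract, with a hypothetical `fetch` helper standing in for the request plumbing:

import requests

def fetch(url: str) -> bytes:
    # Hypothetical stand-in for HFHubRequest.get_response. Without
    # raise_for_status(), a 404 would silently return error HTML; a bare
    # `except:` around it would then mask exactly that kind of issue.
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # raises requests.exceptions.HTTPError on 4xx/5xx
    return response.content

try:
    data = fetch("https://huggingface.co/gpt2/resolve/main/tokenizer.json")
except requests.exceptions.HTTPError as e:
    print(f"Download failed: {e}")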
@@ -173,7 +123,7 @@ class HFHubTokenizer(HFHubBase):
         super().__init__(model_path, logger)
 
     @staticmethod
-    def get_vocab_filenames(vocab_type: VocabType) -> tuple[str]:
+    def list_vocab_files(vocab_type: VocabType) -> tuple[str]:
         if vocab_type == VocabType.SPM:
             return HF_TOKENIZER_SPM_FILES
         # NOTE: WPM and BPE are equivalent
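One nit this hunk carries over unchanged: `tuple[str]` annotates a 1-tuple, while these methods return variable-length filename tuples, which is spelled `tuple[str, ...]`. A two-line illustration:

# tuple[str] means exactly one str element; tuple[str, ...] is the
# homogeneous, any-length form that matches a tuple of vocab filenames.
one_file: tuple[str] = ("tokenizer.json",)
any_files: tuple[str, ...] = ("tokenizer.json", "tokenizer_config.json")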
@@ -184,7 +134,7 @@ class HFHubTokenizer(HFHubBase):
         return VOCAB_TYPE_NAMES.get(vocab_type)
 
     @staticmethod
-    def get_vocab_enum(vocab_name: str) -> VocabType:
+    def get_vocab_type(vocab_name: str) -> VocabType:
         return {
             "SPM": VocabType.SPM,
             "BPE": VocabType.BPE,
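The renamed `get_vocab_type` resolves a name through a literal dict, and the hunk cuts off before any fallback. A self-contained sketch of the pattern with a guessed enum; the real `VocabType` members and default are whatever the module actually defines:

from enum import IntEnum, auto

class VocabType(IntEnum):
    # Hypothetical members; only SPM, BPE, and WPM are visible in this diff.
    NON = auto()
    SPM = auto()
    BPE = auto()
    WPM = auto()

def get_vocab_type(vocab_name: str) -> VocabType:
    # dict.get with a default avoids a KeyError on unrecognized names
    mapping = {"SPM": VocabType.SPM, "BPE": VocabType.BPE, "WPM": VocabType.WPM}
    return mapping.get(vocab_name, VocabType.NON)

assert get_vocab_type("BPE") is VocabType.BPE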
@@ -255,3 +205,58 @@ class HFHubTokenizer(HFHubBase):
         for x, y in v.items():
             if x not in ["vocab", "merges"]:
                 self.logger.info(f"{k}:{x}:{json.dumps(y, indent=2)}")
+
+
+class HFHubModel(HFHubBase):
+    def __init__(
+        self,
+        auth_token: None | str,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger
+    ):
+        super().__init__(model_path, logger)
+
+        self._request = HFHubRequest(auth_token, model_path, logger)
+        self._tokenizer = HFHubTokenizer(model_path, logger)
+
+    @property
+    def request(self) -> HFHubRequest:
+        return self._request
+
+    @property
+    def tokenizer(self) -> HFHubTokenizer:
+        return self._tokenizer
+
+    def _request_single_file(
+        self, model_repo: str, file_name: str, file_path: pathlib.Path
+    ) -> bool:
+        # NOTE: Do not use bare exceptions! They mask issues!
+        # Allow the exception to occur or explicitly handle it.
+        try:
+            self.logger.info(f"Downloading '{file_name}' from {model_repo}")
+            resolved_url = self.request.resolve_url(model_repo, file_name)
+            response = self.request.get_response(resolved_url)
+            self.write_file(response.content, file_path)
+            self.logger.info(f"Model file successfully saved to {file_path}")
+            return True
+        except requests.exceptions.HTTPError as e:
+            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
+            return False
+
+    def _request_listed_files(self, model_repo: str, remote_files: list[str]) -> None:
+        for file_name in remote_files:
+            dir_path = self.model_path / model_repo
+            os.makedirs(dir_path, exist_ok=True)
+            self._request_single_file(model_repo, file_name, dir_path / file_name)
+
+    def download_model_files(self, model_repo: str, file_type: ModelFileType) -> None:
+        filtered_files = self.request.list_filtered_remote_files(model_repo, file_type)
+        self._request_listed_files(model_repo, filtered_files)
+
+    def download_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
+        vocab_files = self.tokenizer.list_vocab_files(vocab_type)
+        self._request_listed_files(model_repo, vocab_files)
+
+    def download_all_model_files(self, model_repo: str) -> None:
+        all_files = self.request.list_remote_files(model_repo)
+        self._request_listed_files(model_repo, all_files)
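For orientation, a usage sketch of the relocated class, using only names this diff defines; the `gguf.huggingface_hub` import path is taken from the script change below, and the model repo is illustrative:

import logging
from gguf.huggingface_hub import HFHubModel, HFHubTokenizer  # import path assumed

logger = logging.getLogger("hub-example")
hub = HFHubModel(auth_token=None, model_path="models", logger=logger)

# HFHubModel now constructs its own HFHubTokenizer, so
# download_all_vocab_files() no longer reaches through a
# self.tokenizer reference that was never assigned.
vocab_type = HFHubTokenizer.get_vocab_type("BPE")
hub.download_all_vocab_files("openai-community/gpt2", vocab_type)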
@@ -1,5 +1,21 @@
 #!/usr/bin/env python3
+"""
+Tokenizers Vocabulary Notes:
+
+Normalizers:
+TODO
+
+Pre-tokenizers:
+
+Byte Level Pre-tokenization uses the openai/gpt-2 RegEx from `encoder.py` by default.
+There are other Pre-tokenization types, e.g. BERT, which inherits from Byte Level.
+The defaults for each RegEx are identical in either case.
+
+Pre-tokenization encompasses identifying characters and their types:
+- A pattern may match a type of "Sequence"
+- Letters and Numbers: Alphabetic or Alphanumeric
+- Whitespace:
+"""
 from __future__ import annotations
 
 import argparse
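The docstring's openai/gpt-2 reference can be made concrete. This is the byte-level split pattern from GPT-2's `encoder.py`; it needs the third-party `regex` package, since the stdlib `re` lacks the `\p{L}`/`\p{N}` character classes:

import regex

# The default byte-level pre-tokenization pattern from openai/gpt-2 encoder.py
GPT2_PRE_TOKENIZER_PATTERN = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

print(GPT2_PRE_TOKENIZER_PATTERN.findall("Hello world, it's 2024!"))
# ['Hello', ' world', ',', ' it', "'s", ' 2024', '!']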
@@ -16,7 +32,7 @@ if (
     sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES
-from gguf.huggingface_hub import HFVocabRequest
+from gguf.huggingface_hub import HFHub, HFTokenizer
 
 logger = logging.getLogger(Path(__file__).stem)
 
@@ -33,7 +49,7 @@ def main():
     parser.add_argument(
         "-m",
         "--model-path",
-        default="models/",
+        default="models",
         help="The models storage path. Default is 'models/'.",
     )
     parser.add_argument(
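The "models/" to "models" change is cosmetic once the value reaches pathlib, since PurePath normalizes away trailing slashes; only the help text keeps the slash. A quick check:

import pathlib

# pathlib drops the trailing slash, so both spellings of the default
# produce the same path object and join identically.
assert pathlib.Path("models/") == pathlib.Path("models")
print(pathlib.Path("models") / "openai-community" / "gpt2")  # models/openai-community/gpt2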
@@ -50,9 +66,11 @@ def main():
     else:
         logging.basicConfig(level=logging.INFO)
 
-    vocab_request = HFVocabRequest(args.auth_token, args.model_path, logger)
-    vocab_type = vocab_request.get_vocab_enum(args.vocab_type)
+    vocab_request = HFHub(args.auth_token, args.model_path, logger)
+    vocab_type = HFTokenizer.get_vocab_type(args.vocab_type)
+    tokenizer = vocab_request.tokenizer
     vocab_request.get_all_vocab_files(args.model_repo, vocab_type)
+    tokenizer.log_tokenizer_json_info(args.model_repo)
 
 
 if __name__ == "__main__":