diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py index 2e7619efb..4667ec829 100644 --- a/gguf-py/gguf/huggingface_hub.py +++ b/gguf-py/gguf/huggingface_hub.py @@ -1,17 +1,24 @@ import logging +import os import pathlib import requests class HuggingFaceHub: - def __init__(self, auth_token: None | str): + def __init__(self, auth_token: None | str, logger: None | logging.Logger): # Set headers if authentication is available if auth_token is None: self._headers = {} else: self._headers = {"Authorization": f"Bearer {auth_token}"} + # Set the logger + if logger is None: + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger("huggingface-hub") + self.logger = logger + # Persist across requests self._session = requests.Session() @@ -34,6 +41,12 @@ class HuggingFaceHub: def base_url(self) -> str: return self._base_url + @staticmethod + def write_file(self, content: bytes, save_path: pathlib.Path) -> None: + with open(save_path, 'wb') as f: + f.write(content) + self.logger.info(f"File {save_path} downloaded successfully") + def resolve_path(self, repo: str, file: str) -> str: return f"{self._base_url}/{repo}/resolve/main/{file}"