From dbdf6c2b1da8e04d9c0acd9d916f09c0c8c1051b Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Fri, 17 May 2024 20:00:48 -0400 Subject: [PATCH] feat: Add prototype for managing huggingface hub content --- gguf-py/gguf/huggingface_hub.py | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 gguf-py/gguf/huggingface_hub.py diff --git a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py new file mode 100644 index 000000000..2e7619efb --- /dev/null +++ b/gguf-py/gguf/huggingface_hub.py @@ -0,0 +1,44 @@ +import logging +import pathlib + +import requests + + +class HuggingFaceHub: + def __init__(self, auth_token: None | str): + # Set headers if authentication is available + if auth_token is None: + self._headers = {} + else: + self._headers = {"Authorization": f"Bearer {auth_token}"} + + # Persist across requests + self._session = requests.Session() + + # This is read-only + self._base_url = "https://huggingface.co" + + @property + def headers(self) -> str: + return self._headers + + @property + def save_path(self) -> pathlib.Path: + return self._save_path + + @property + def session(self) -> requests.Session: + return self._session + + @property + def base_url(self) -> str: + return self._base_url + + def resolve_path(self, repo: str, file: str) -> str: + return f"{self._base_url}/{repo}/resolve/main/{file}" + + def download_file(self, repo: str, file: str): + endpoint = self.resolve_path(repo, file) + response = self._session.get(endpoint, headers=self.headers) + response.raise_for_status() + return response