refactor: Apply fix for file path references

This commit is contained in:
teleprint-me 2024-05-23 22:59:16 -04:00
parent c91dcdf2a4
commit e62e09bbb1
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@@ -53,10 +53,10 @@ class HFHubRequest:
def base_url(self) -> str: def base_url(self) -> str:
return self._base_url return self._base_url
def write_file(self, content: bytes, filepath: pathlib.Path) -> None: def write_file(self, content: bytes, file_path: pathlib.Path) -> None:
with open(filepath, 'wb') as f: with open(file_path, 'wb') as f:
f.write(content) f.write(content)
self.logger.info(f"Wrote {len(content)} bytes to {filepath} successfully") self.logger.info(f"Wrote {len(content)} bytes to {file_path} successfully")
def resolve_url(self, repo: str, filename: str) -> str: def resolve_url(self, repo: str, filename: str) -> str:
return f"{self._base_url}/{repo}/resolve/main/{filename}" return f"{self._base_url}/{repo}/resolve/main/{filename}"
@@ -149,7 +149,10 @@ class HFVocabRequest(HFHubBase):
def get_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None: def get_all_vocab_files(self, model_repo: str, vocab_type: VocabType) -> None:
vocab_list = self.get_vocab_filenames(vocab_type) vocab_list = self.get_vocab_filenames(vocab_type)
for vocab_file in vocab_list: for vocab_file in vocab_list:
self.get_vocab_file(model_repo, vocab_file, self.model_path) dir_path = self.model_path / model_repo
file_path = dir_path / vocab_file
os.makedirs(dir_path, exist_ok=True)
self.get_vocab_file(model_repo, vocab_file, file_path)
def get_normalizer(self) -> None | dict[str, object]: def get_normalizer(self) -> None | dict[str, object]:
with open(self.tokenizer_path, mode="r") as file: with open(self.tokenizer_path, mode="r") as file:
@@ -165,10 +168,10 @@ class HFVocabRequest(HFHubBase):
checksums = [] checksums = []
for model in self.models: for model in self.models:
mapping = {} mapping = {}
filepath = f"{self.model_path}/{model['repo']}" file_path = f"{self.model_path}/{model['repo']}"
try: try:
tokenizer = AutoTokenizer.from_pretrained(filepath, trust_remote=True) tokenizer = AutoTokenizer.from_pretrained(file_path, trust_remote=True)
except OSError as e: except OSError as e:
self.logger.error(f"Failed to hash tokenizer {model['repo']}: {e}") self.logger.error(f"Failed to hash tokenizer {model['repo']}: {e}")
continue continue