chore: Fix model path references

This commit is contained in:
teleprint-me 2024-05-18 19:20:19 -04:00
parent b6f70b8a0e
commit 006bb60d27
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -174,12 +174,12 @@ class HFVocabRequest(HFHubBase):
except requests.exceptions.HTTPError as e:
self.logger.error(f"Failed to download tokenizer {model['name']}: {e}")
def download_model(self) -> None:
def download_models(self) -> None:
for model in self.models:
os.makedirs(f"{self.local_path}/{model['repo']}", exist_ok=True)
os.makedirs(f"{self.model_path}/{model['repo']}", exist_ok=True)
filenames = self.resolve_filenames(model['tokt'])
for filename in filenames:
filepath = pathlib.Path(f"{self.local_path}/{model['repo']}/{filename}")
filepath = pathlib.Path(f"{self.model_path}/{model['repo']}/{filename}")
if filepath.is_file():
self.logger.info(f"skipped pre-existing tokenizer {model['repo']} in {filepath}")
continue
@ -189,19 +189,19 @@ class HFVocabRequest(HFHubBase):
checksums = []
for model in self.models:
mapping = {}
filepath = f"{self.local_path}/{model['repo']}"
filepath = f"{self.model_path}/{model['repo']}"
tokenizer = AutoTokenizer.from_pretrained(filepath, trust_remote=True)
mapping.update(model)
mapping['checksum'] = sha256(str(tokenizer.vocab).encode()).hexdigest()
self.logger.info(f"Hashed {mapping['repo']} as {mapping['checksum']}")
checksums.append(mapping)
with open(f"{self.local_path.parent}/checksums.json") as file:
with open(f"{self.model_path.parent}/checksums.json") as file:
json.dump(checksums, file)
def log_pre_tokenizer_info(self) -> None:
for model in self.models:
with open(f"{self.local_path}/{model['repo']}/tokenizer.json", "r", encoding="utf-8") as f:
with open(f"{self.model_path}/{model['repo']}/tokenizer.json", "r", encoding="utf-8") as f:
cfg = json.load(f)
self.logger.info(f"normalizer: {json.dumps(cfg['normalizer'], indent=4)}")
self.logger.info(f"pre_tokenizer: {json.dumps(cfg['pre_tokenizer'], indent=4)}")
@ -221,3 +221,7 @@ class HFModelRequest(HFHubBase):
logger: None | logging.Logger
):
super().__init__(model_path, auth_token, logger)
# Read-only property that hands back the ModelType name referenced below.
# NOTE(review): the body returns the ModelType object itself (apparently the
# class, not a specific member or instance), while the return annotation reads
# as an instance of ModelType — confirm which one callers actually expect.
@property
def model_type(self) -> ModelType:
return ModelType