Check for single-file safetensors model

This commit is contained in:
afrideva 2023-11-11 20:20:07 -08:00 committed by GitHub
parent 4dce910cbc
commit c7bae1e125
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1036,7 +1036,8 @@ def load_some_model(path: Path) -> ModelPlus:
# Be extra-friendly and accept either a file or a directory:
if path.is_dir():
# Check if it's a set of safetensors files first
files = list(path.glob("model-00001-of-*.safetensors"))
globs = ["model-00001-of-*.safetensors", "model.safetensors"]
files = [file for glob in globs for file in path.glob(glob)]
if not files:
# Try the PyTorch patterns too, with lower priority
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]