Check for safetensors files first, and only use PyTorch versions when safetensors aren't available

This commit is contained in:
ubik2 2023-05-08 00:56:12 -07:00 committed by GitHub
parent b8279c82d0
commit d8c36c91f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1035,11 +1035,11 @@ def load_some_model(path: Path) -> ModelPlus:
'''Load a model of any supported format.'''
# Be extra-friendly and accept either a file or a directory:
if path.is_dir():
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
files = [file for glob in globs for file in path.glob(glob)]
# Check if it's a set of safetensors files first
files = list(path.glob("model-00001-of-*.safetensors"))
if not files:
# Check if it's a set of safetensors files
globs = ["model-00001-of-*.safetensors"]
# Try the PyTorch patterns too, with lower priority
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
files = [file for glob in globs for file in path.glob(glob)]
if not files:
# Try GGML too, but with lower priority, since if both a non-GGML