Check for safetensors files first, and only use PyTorch versions when safetensors aren't available
This commit is contained in:
parent
b8279c82d0
commit
d8c36c91f8
1 changed files with 4 additions and 4 deletions
|
@ -1035,11 +1035,11 @@ def load_some_model(path: Path) -> ModelPlus:
|
|||
'''Load a model of any supported format.'''
|
||||
# Be extra-friendly and accept either a file or a directory:
|
||||
if path.is_dir():
|
||||
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
|
||||
files = [file for glob in globs for file in path.glob(glob)]
|
||||
# Check if it's a set of safetensors files first
|
||||
files = list(path.glob("model-00001-of-*.safetensors"))
|
||||
if not files:
|
||||
# Check if it's a set of safetensors files
|
||||
globs = ["model-00001-of-*.safetensors"]
|
||||
# Try the PyTorch patterns too, with lower priority
|
||||
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
|
||||
files = [file for glob in globs for file in path.glob(glob)]
|
||||
if not files:
|
||||
# Try GGML too, but with lower priority, since if both a non-GGML
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue