From b59c371035735aa551c3760fe6412f7841405874 Mon Sep 17 00:00:00 2001
From: alex
Date: Wed, 3 May 2023 23:57:08 +0200
Subject: [PATCH] add support for ByteStorage, relax model glob

---
 convert.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/convert.py b/convert.py
index 7f7ae05fa..245c1915d 100644
--- a/convert.py
+++ b/convert.py
@@ -40,6 +40,7 @@ class UnquantizedDataType:
 
 DT_F16 = UnquantizedDataType('F16')
 DT_F32 = UnquantizedDataType('F32')
+DT_U8 = UnquantizedDataType('U8')
 DT_I32 = UnquantizedDataType('I32')
 DT_BF16 = UnquantizedDataType('BF16')
 
@@ -69,6 +70,7 @@ FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_F16: np.dtype(np.float16),
     DT_F32: np.dtype(np.float32),
+    DT_U8: np.dtype(np.uint8),
     DT_I32: np.dtype(np.int32),
 }
 
@@ -702,6 +704,7 @@ class LazyUnpickler(pickle.Unpickler):
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
         ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
+        ('torch', 'ByteStorage'): LazyStorageKind(DT_U8),
     }
 
     def find_class(self, module: str, name: str) -> Any:
@@ -726,6 +729,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'F16': DT_F16,
     'F32': DT_F32,
+    'U8': DT_U8,
     'I32': DT_I32,
 }
 
@@ -1035,7 +1039,7 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+        globs = ["consolidated.00.pth", "pytorch_model*.bin", "*.pt"]
         files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
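
Note (not part of the patch): two minimal sketches of what the change
enables, using only pathlib and the published safetensors file layout;
the directory and file names below are hypothetical.

The relaxed glob in load_some_model() now also matches single-file
checkpoints named pytorch_model.bin, not only sharded ones:

    from pathlib import Path

    model_dir = Path("models/example")  # hypothetical directory
    old_glob = "pytorch_model-00001-of-*.bin"  # misses pytorch_model.bin
    new_glob = "pytorch_model*.bin"            # matches shards and single file
    print(sorted(model_dir.glob(old_glob)))
    print(sorted(model_dir.glob(new_glob)))

The new 'U8' entry covers safetensors tensors declared with dtype "U8"
(mirrored on the torch side by mapping ByteStorage to DT_U8). A quick way
to see which dtypes a file declares is to read its JSON header, which is
prefixed by an 8-byte little-endian length:

    import json
    import struct

    with open("model.safetensors", "rb") as fp:  # hypothetical file
        header_len, = struct.unpack("<Q", fp.read(8))
        header = json.loads(fp.read(header_len))
    for name, info in header.items():
        if name != "__metadata__":  # optional metadata entry, has no dtype
            print(name, info["dtype"], info["shape"])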