add support for ByteStorage, relax model glob

Author: alex
Date:   2023-05-03 23:57:08 +02:00
parent 6daa09d879
commit b59c371035


@@ -40,6 +40,7 @@ class UnquantizedDataType:
 DT_F16 = UnquantizedDataType('F16')
 DT_F32 = UnquantizedDataType('F32')
+DT_U8 = UnquantizedDataType('U8')
 DT_I32 = UnquantizedDataType('I32')
 DT_BF16 = UnquantizedDataType('BF16')
@@ -69,6 +70,7 @@ FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_F16: np.dtype(np.float16),
     DT_F32: np.dtype(np.float32),
+    DT_U8: np.dtype(np.uint8),
     DT_I32: np.dtype(np.int32),
 }
@@ -702,6 +704,7 @@ class LazyUnpickler(pickle.Unpickler):
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
         ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
+        ('torch', 'ByteStorage'): LazyStorageKind(DT_U8),
     }

     def find_class(self, module: str, name: str) -> Any:
@@ -726,6 +729,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'F16': DT_F16,
     'F32': DT_F32,
+    'U8': DT_U8,
     'I32': DT_I32,
 }
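
Note (not part of the diff): torch's ByteStorage holds unsigned 8-bit data, which is why both the pickle path and the safetensors path route it to the new DT_U8 type and, via DATA_TYPE_TO_NUMPY, to numpy's uint8. A minimal sketch of that correspondence, assuming torch and numpy are installed:

    # Illustrative only: a uint8 tensor is what a pickled torch ByteStorage
    # deserializes into, and its numpy view carries the np.uint8 dtype that
    # DT_U8 now maps to.
    import numpy as np
    import torch

    t = torch.arange(4, dtype=torch.uint8)   # storage type referenced as ByteStorage in torch pickles
    arr = t.numpy()                          # shares memory; dtype is uint8
    assert arr.dtype == np.dtype(np.uint8)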
@@ -1035,7 +1039,7 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+        globs = ["consolidated.00.pth", "pytorch_model*.bin", "*.pt"]
         files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
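
Side note on the glob change (illustration only, filenames hypothetical): the old pattern only caught the first shard of a sharded Hugging Face checkpoint, while the relaxed pytorch_model*.bin also catches a single-file pytorch_model.bin as well as the remaining shards:

    # Sketch of what each pattern matches, using fnmatch for illustration.
    from fnmatch import fnmatch

    names = ["pytorch_model.bin",
             "pytorch_model-00001-of-00002.bin",
             "pytorch_model-00002-of-00002.bin"]

    print([n for n in names if fnmatch(n, "pytorch_model-00001-of-*.bin")])
    # old glob -> ['pytorch_model-00001-of-00002.bin']
    print([n for n in names if fnmatch(n, "pytorch_model*.bin")])
    # new glob -> matches all three names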