add support for ByteStorage, relax model glob
parent 6daa09d879
commit b59c371035
1 changed file with 5 additions and 1 deletion
@@ -40,6 +40,7 @@ class UnquantizedDataType:
 
 DT_F16 = UnquantizedDataType('F16')
 DT_F32 = UnquantizedDataType('F32')
+DT_U8 = UnquantizedDataType('U8')
 DT_I32 = UnquantizedDataType('I32')
 DT_BF16 = UnquantizedDataType('BF16')
 
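For context on the constant added above, a minimal self-contained sketch, assuming UnquantizedDataType is a frozen dataclass carrying only a format name (the class body is not part of this diff):

from dataclasses import dataclass

@dataclass(frozen=True)
class UnquantizedDataType:
    name: str

# The new tag for raw byte tensors, alongside the existing F16/F32/I32/BF16:
DT_U8 = UnquantizedDataType('U8')
print(DT_U8)  # UnquantizedDataType(name='U8')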
@@ -69,6 +70,7 @@ FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_F16: np.dtype(np.float16),
     DT_F32: np.dtype(np.float32),
+    DT_U8: np.dtype(np.uint8),
     DT_I32: np.dtype(np.int32),
 }
 
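What the new mapping entry enables, as a stand-alone sketch (a plain string key stands in for the DT_U8 object): raw byte buffers can be materialized through the same numpy lookup table as the float and int types.

import numpy as np

DATA_TYPE_TO_NUMPY = {'U8': np.dtype(np.uint8)}  # string key stands in for DT_U8
buf = bytes([0, 127, 255])
arr = np.frombuffer(buf, dtype=DATA_TYPE_TO_NUMPY['U8'])
print(arr, arr.itemsize)  # [  0 127 255] 1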
@@ -702,6 +704,7 @@ class LazyUnpickler(pickle.Unpickler):
         ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
         ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
         ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
+        ('torch', 'ByteStorage'): LazyStorageKind(DT_U8),
     }
 
     def find_class(self, module: str, name: str) -> Any:
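The table above is consulted from LazyUnpickler.find_class. A hedged sketch of that interception pattern, with a placeholder string standing in for LazyStorageKind(DT_U8) so it runs without torch installed:

import io
import pickle

LAZY_STORAGE = {('torch', 'ByteStorage'): 'LazyStorageKind(DT_U8)'}  # placeholder value

class SketchUnpickler(pickle.Unpickler):
    def find_class(self, module: str, name: str):
        # Swap known storage classes for lazy placeholders instead of
        # resolving them; everything else falls through to normal lookup.
        if (module, name) in LAZY_STORAGE:
            return LAZY_STORAGE[(module, name)]
        return super().find_class(module, name)

data = pickle.dumps({'ok': True})
print(SketchUnpickler(io.BytesIO(data)).load())  # {'ok': True}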
@@ -726,6 +729,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'F16': DT_F16,
     'F32': DT_F32,
+    'U8': DT_U8,
     'I32': DT_I32,
 }
 
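Where the 'U8' key comes from: a safetensors file begins with an 8-byte little-endian header size followed by a JSON table whose per-tensor dtype field uses exactly these short names. A runnable sketch against a synthetic header:

import io
import json
import struct

def read_safetensors_header(fp):
    # 8-byte little-endian header size, then a JSON table of tensors.
    header_size, = struct.unpack('<Q', fp.read(8))
    return json.loads(fp.read(header_size))

hdr = json.dumps({"t": {"dtype": "U8", "shape": [3], "data_offsets": [0, 3]}}).encode()
blob = struct.pack('<Q', len(hdr)) + hdr + bytes(3)
print(read_safetensors_header(io.BytesIO(blob))["t"]["dtype"])  # U8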
@@ -1035,7 +1039,7 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+        globs = ["consolidated.00.pth", "pytorch_model*.bin", "*.pt"]
         files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
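A quick check of why the glob was relaxed (fnmatch applies the same wildcard rules Path.glob uses here): the old pattern matched only the first shard of a sharded HF checkpoint and missed a single-file pytorch_model.bin; the new pattern matches both.

from fnmatch import fnmatch

old, new = "pytorch_model-00001-of-*.bin", "pytorch_model*.bin"
for name in ("pytorch_model.bin", "pytorch_model-00001-of-00002.bin"):
    print(name, fnmatch(name, old), fnmatch(name, new))
# pytorch_model.bin False True
# pytorch_model-00001-of-00002.bin True True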