refactor gguf reader

This commit is contained in:
isotr0py 2024-10-31 20:19:24 +08:00
parent fc83a9e584
commit 5ce2dbcf38

View file

@ -87,7 +87,8 @@ class GGUFReader:
} }
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode) self.data = open(path, mode="rb")
self.mmap = np.memmap(path, mode = mode)
offs = 0 offs = 0
# Check for GGUF magic # Check for GGUF magic
@ -127,7 +128,8 @@ class GGUFReader:
if padding != 0: if padding != 0:
offs += self.alignment - padding offs += self.alignment - padding
self.data_offset = offs self.data_offset = offs
self._build_tensors(offs, tensors_fields) # self._build_tensors(offs, tensors_fields)
self.data.close()
_DT = TypeVar('_DT', bound = npt.DTypeLike) _DT = TypeVar('_DT', bound = npt.DTypeLike)
@ -140,13 +142,19 @@ class GGUFReader:
return self.tensors[idx] return self.tensors[idx]
def _get( def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, lazy: bool = False,
) -> npt.NDArray[Any]: ) -> npt.NDArray[Any]:
count = int(count) count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize) itemsize = np.dtype(dtype).itemsize
end_offs = offset + itemsize * count if not lazy:
self.data.seek(offset)
return ( return (
self.data[offset:end_offs] np.frombuffer(self.data.read(itemsize * count), dtype = dtype, count = count)
.newbyteorder(override_order or self.byte_order)
)
else:
return (
self.mmap[offset:offset + itemsize * count]
.view(dtype = dtype)[:count] .view(dtype = dtype)[:count]
.newbyteorder(override_order or self.byte_order) .newbyteorder(override_order or self.byte_order)
) )
@ -311,7 +319,7 @@ class GGUFReader:
n_elements = n_elems, n_elements = n_elems,
n_bytes = n_bytes, n_bytes = n_bytes,
data_offset = data_offs, data_offset = data_offs,
data = self._get(data_offs, item_type, item_count).reshape(np_dims), data = self._get(data_offs, item_type, item_count, lazy=True).reshape(np_dims),
field = field, field = field,
)) ))
self.tensors = tensors self.tensors = tensors