From 07ef1a8a042ff7cdec89efc500c16f911b0c1eeb Mon Sep 17 00:00:00 2001 From: isotr0py <2037008807@qq.com> Date: Tue, 5 Nov 2024 14:01:38 +0800 Subject: [PATCH] make mode compatible --- gguf-py/gguf/gguf_reader.py | 42 +++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index ee94ae056..eb1068b7a 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -88,7 +88,7 @@ class GGUFReader: } def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): - file_mode = "rb" if mode == 'r' else 'rb+' + file_mode = "rb+" if mode == 'r+' else 'rb' self.mode = mode self.data = open(path, mode=file_mode) self.mmap = np.memmap(self.data, mode = mode) @@ -147,17 +147,22 @@ class GGUFReader: return self.tensors[idx] def _get( - self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, + self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, use_mmap: bool = False ) -> npt.NDArray[Any]: count = int(count) itemsize = np.dtype(dtype).itemsize end_offs = offset + itemsize * count - self.data.seek(end_offs) - return ( - self.mmap[offset:end_offs] - .view(dtype = dtype)[:count] - .newbyteorder(override_order or self.byte_order) - ) + if self.mode != "r" or use_mmap: + data = ( + self.mmap[offset:end_offs] + .view(dtype = dtype)[:count] + .newbyteorder(override_order or self.byte_order) + ) + self.data.seek(end_offs) + else: + self.data.seek(offset) + data = np.frombuffer(self.data.read(itemsize * count), dtype = dtype) + return data def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: @@ -170,14 +175,15 @@ class GGUFReader: self.fields[field.name] = field return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) - def _get_str(self, offset: int, return_size=False) 
-> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + def _get_str(self, offset: int) -> list[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: self.data.seek(offset) - slen = struct.unpack('