make mode compatible

isotr0py 2024-11-05 14:01:38 +08:00
parent 6a13722ca5
commit 07ef1a8a04


@@ -88,7 +88,7 @@ class GGUFReader:
     }
 
     def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
-        file_mode = "rb" if mode == 'r' else 'rb+'
+        file_mode = "rb+" if mode == 'r+' else 'rb'
         self.mode = mode
         self.data = open(path, mode=file_mode)
         self.mmap = np.memmap(self.data, mode = mode)
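Note on the change above: the old expression opened the file as 'rb+' for every mode except 'r', so the copy-on-write mode 'c' also demanded a writable file handle; the new expression only requests 'rb+' for 'r+'. A minimal sketch of the mapping (the helper name is illustrative, not from the commit):

# Illustration only: pair the np.memmap mode with the underlying file mode.
# Only 'r+' needs a writable handle; 'r' and copy-on-write 'c' can use plain 'rb'.
from typing import Literal

def pick_file_mode(mode: Literal['r', 'r+', 'c']) -> str:
    return "rb+" if mode == 'r+' else 'rb'

assert pick_file_mode('r') == 'rb'
assert pick_file_mode('r+') == 'rb+'
assert pick_file_mode('c') == 'rb'   # copy-on-write: writes stay in memory, file untouched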
@@ -147,17 +147,22 @@ class GGUFReader:
         return self.tensors[idx]
 
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, use_mmap: bool = False
     ) -> npt.NDArray[Any]:
         count = int(count)
         itemsize = np.dtype(dtype).itemsize
         end_offs = offset + itemsize * count
-        self.data.seek(end_offs)
-        return (
-            self.mmap[offset:end_offs]
-            .view(dtype = dtype)[:count]
-            .newbyteorder(override_order or self.byte_order)
-        )
+        if self.mode != "r" or use_mmap:
+            data = (
+                self.mmap[offset:end_offs]
+                .view(dtype = dtype)[:count]
+                .newbyteorder(override_order or self.byte_order)
+            )
+            self.data.seek(end_offs)
+        else:
+            self.data.seek(offset)
+            data = np.frombuffer(self.data.read(itemsize * count), dtype = dtype)
+        return data
 
     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:
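The reworked _get above switches between a zero-copy view into the mmap and an ordinary buffered read through the file handle. A self-contained sketch of the two paths, using a scratch file and names invented for illustration (not the commit's code):

# Write 8 little uint32 values, then read elements 2 and 3 both ways.
import numpy as np
import tempfile

with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
    np.arange(8, dtype=np.uint32).tofile(f)
    path = f.name

data = open(path, "rb")
mm = np.memmap(data, mode="r")                 # uint8 view over the whole file

offset, count, dtype = 8, 2, np.uint32
itemsize = np.dtype(dtype).itemsize
end_offs = offset + itemsize * count

via_mmap = mm[offset:end_offs].view(dtype=dtype)[:count]            # view backed by the mmap
data.seek(offset)
via_read = np.frombuffer(data.read(itemsize * count), dtype=dtype)  # independent copy

assert (via_mmap == via_read).all()            # both yield [2, 3]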
@@ -170,14 +175,15 @@ class GGUFReader:
         self.fields[field.name] = field
         return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
 
-    def _get_str(self, offset: int, return_size=False) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
+    def _get_str(self, offset: int) -> list[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
         self.data.seek(offset)
-        slen = struct.unpack('<Q', self.data.read(8))
-        sdata = struct.unpack('<' + 'B' * slen[0], self.data.read(slen[0]))
-        output = (slen, sdata)
-        if return_size:
-            output += (8 + slen[0],)
-        return output
+        if self.mode != "r":
+            slen = self._get(offset, np.uint64)
+            sdata = self._get(offset + 8, np.uint8, slen[0])
+        else:
+            slen = np.frombuffer(self.data.read(8), dtype = np.uint64)
+            sdata = np.frombuffer(self.data.read(slen.item()), dtype = np.uint8)
+        return [slen, sdata]
 
     def _get_field_parts(
         self, orig_offs: int, raw_type: int,
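The rewritten _get_str above decodes the GGUF string layout: a little-endian uint64 byte length followed by the raw UTF-8 payload, which is also where the 8 + length size used in the next hunk comes from. A small sketch of that layout (the example bytes are made up, and it assumes a little-endian host since the reader's byte-order handling is omitted):

# GGUF string: uint64 length prefix, then the UTF-8 bytes.
import struct
import numpy as np

raw = struct.pack('<Q', 5) + b'llama'                         # prefix + payload

slen = np.frombuffer(raw[:8], dtype=np.uint64)                # array([5], dtype=uint64)
sdata = np.frombuffer(raw[8:8 + slen.item()], dtype=np.uint8)
size = 8 + slen.item()                                        # total bytes on disk: prefix + payload
print(slen.item(), bytes(sdata).decode('utf-8'), size)        # 5 llama 13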
@@ -188,8 +194,8 @@ class GGUFReader:
         types.append(gtype)
         # Handle strings.
         if gtype == GGUFValueType.STRING:
-            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs, return_size=True))
-            size = sparts.pop(-1)
+            sparts: list[npt.NDArray[Any]] = self._get_str(offs)
+            size = 8 + sparts[0].item()
             return size, sparts, [1], types
         # Check if it's a simple scalar type.
         nptype = self.gguf_scalar_to_np.get(gtype)
@@ -324,7 +330,7 @@ class GGUFReader:
                 n_elements = n_elems,
                 n_bytes = n_bytes,
                 data_offset = data_offs,
-                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
+                data = self._get(data_offs, item_type, item_count, use_mmap=True).reshape(np_dims),
                 field = field,
             ))
         self.tensors = tensors
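Taken together, the changes let the reader be opened read-only while tensor data stays memory-mapped (hence use_mmap=True above). A hedged usage sketch, assuming the class is importable as gguf.GGUFReader and that a model.gguf file exists, neither of which is stated by the commit:

# Read-only open: metadata reads go through the file handle,
# tensor data is still served from the memory map.
from gguf import GGUFReader

reader = GGUFReader("model.gguf", mode="r")
for tensor in reader.tensors:
    print(tensor.name, tensor.shape, tensor.data.dtype)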