This commit is contained in:
Isotr0py 2024-12-12 18:35:08 +01:00 committed by GitHub
commit ddb5a0254f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import logging import logging
import os import os
import struct
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union from typing import Any, Literal, NamedTuple, TypeVar, Union
@ -87,11 +88,15 @@ class GGUFReader:
} }
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode) file_mode = "rb+" if mode == 'r+' else 'rb'
self.mode = mode
self.data = open(path, mode=file_mode)
self.mmap = np.memmap(self.data, mode = mode)
offs = 0 offs = 0
# Check for GGUF magic # Check for GGUF magic
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: self.data.seek(offs)
if struct.unpack("<I", self.data.read(4))[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid') raise ValueError('GGUF magic invalid')
offs += 4 offs += 4
@ -129,6 +134,9 @@ class GGUFReader:
self.data_offset = offs self.data_offset = offs
self._build_tensors(offs, tensors_fields) self._build_tensors(offs, tensors_fields)
def __del__(self) -> None:
    """Close the underlying file handle when the reader is garbage-collected.

    ``__del__`` also runs for instances whose ``__init__`` raised before
    ``self.data`` was assigned (e.g. ``open()`` failed), so the attribute is
    looked up defensively instead of assuming it exists.
    """
    handle = getattr(self, 'data', None)
    if handle is not None:
        handle.close()
_DT = TypeVar('_DT', bound = npt.DTypeLike) _DT = TypeVar('_DT', bound = npt.DTypeLike)
# Fetch a key/value metadata field by key. # Fetch a key/value metadata field by key.
@ -140,16 +148,24 @@ class GGUFReader:
return self.tensors[idx] return self.tensors[idx]
def _get(
    self, offset: int, dtype: npt.DTypeLike, count: int = 1,
    override_order: None | Literal['I', 'S', '<'] = None, use_mmap: bool = False,
) -> npt.NDArray[Any]:
    """Read ``count`` items of ``dtype`` starting at byte ``offset``.

    Args:
        offset: Absolute byte offset into the file.
        dtype: Anything accepted by ``np.dtype``.
        count: Number of items to read.
        override_order: Byte order to apply instead of ``self.byte_order``.
        use_mmap: Force the memory-mapped path even when the reader was
            opened with mode ``'r'``.

    Returns:
        A 1-D array of ``count`` items. On the mmap path it is a view into
        ``self.mmap``; on the file path it is a read-only array built from
        the bytes read out of ``self.data``.

    Raises:
        ValueError: On the file path, if fewer bytes than requested could be
            read.

    In both paths the file position of ``self.data`` is left at the end of
    the region that was read; callers use ``self.data.tell()`` to advance.
    """
    count = int(count)
    dtype = np.dtype(dtype)
    order = override_order or self.byte_order
    nbytes = dtype.itemsize * count
    end_offs = offset + nbytes
    if self.mode != "r" or use_mmap:
        raw = self.mmap[offset:end_offs].view(dtype = dtype)[:count]
        # ndarray.newbyteorder() was removed in NumPy 2.0; re-viewing with a
        # byte-swapped dtype is the documented equivalent on all versions.
        data = raw.view(raw.dtype.newbyteorder(order))
        # Keep the file cursor in step with the mmap read.
        self.data.seek(end_offs)
    else:
        self.data.seek(offset)
        buf = self.data.read(nbytes)
        if len(buf) != nbytes:
            # A short read would otherwise silently yield fewer items.
            raise ValueError(
                f'Unexpected end of file: wanted {nbytes} bytes at offset {offset}, got {len(buf)}'
            )
        data = np.frombuffer(buf, dtype = dtype.newbyteorder(order))
    return data
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
if field.name in self.fields: if field.name in self.fields:
@ -162,9 +178,18 @@ class GGUFReader:
self.fields[field.name] = field self.fields[field.name] = field
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
def _get_str(self, offset: int) -> list[npt.NDArray[Any]]:
    """Read a length-prefixed string located at byte ``offset``.

    The on-disk layout read here is a ``uint64`` length followed by that
    many raw bytes.

    Returns:
        A two-element list ``[length, data]``: ``length`` is a one-element
        ``uint64`` array and ``data`` is a ``uint8`` array of the string
        bytes.

    Raises:
        ValueError: In read-only mode, if the file ends before the length
            prefix or the string payload is complete.

    The file position of ``self.data`` is left just past the string.
    """
    if self.mode != "r":
        # Writable modes go through _get.
        slen = self._get(offset, np.uint64)
        sdata = self._get(offset + 8, np.uint8, slen.item())
    else:
        # Read-only fast path: one seek and two sequential reads instead of
        # a seek per part.
        self.data.seek(offset)
        len_bytes = self.data.read(8)
        if len(len_bytes) != 8:
            raise ValueError(f'Unexpected end of file reading string length at offset {offset}')
        u64 = np.dtype(np.uint64).newbyteorder(self.byte_order)
        slen = np.frombuffer(len_bytes, dtype=u64)
        nbytes = slen.item()
        payload = self.data.read(nbytes)
        if len(payload) != nbytes:
            raise ValueError(
                f'Unexpected end of file reading {nbytes}-byte string at offset {offset + 8}'
            )
        # Byte order is meaningless for single-byte items.
        sdata = np.frombuffer(payload, dtype=np.uint8)
    return [slen, sdata]
def _get_field_parts( def _get_field_parts(
self, orig_offs: int, raw_type: int, self, orig_offs: int, raw_type: int,
@ -175,8 +200,8 @@ class GGUFReader:
types.append(gtype) types.append(gtype)
# Handle strings. # Handle strings.
if gtype == GGUFValueType.STRING: if gtype == GGUFValueType.STRING:
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs)) sparts: list[npt.NDArray[Any]] = self._get_str(offs)
size = sum(int(part.nbytes) for part in sparts) size = 8 + sparts[0].item()
return size, sparts, [1], types return size, sparts, [1], types
# Check if it's a simple scalar type. # Check if it's a simple scalar type.
nptype = self.gguf_scalar_to_np.get(gtype) nptype = self.gguf_scalar_to_np.get(gtype)
@ -186,9 +211,9 @@ class GGUFReader:
# Handle arrays. # Handle arrays.
if gtype == GGUFValueType.ARRAY: if gtype == GGUFValueType.ARRAY:
raw_itype = self._get(offs, np.uint32) raw_itype = self._get(offs, np.uint32)
offs += int(raw_itype.nbytes) offs = self.data.tell()
alen = self._get(offs, np.uint64) alen = self._get(offs, np.uint64)
offs += int(alen.nbytes) offs = self.data.tell()
aparts: list[npt.NDArray[Any]] = [raw_itype, alen] aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
data_idxs: list[int] = [] data_idxs: list[int] = []
for idx in range(alen[0]): for idx in range(alen[0]):
@ -208,23 +233,23 @@ class GGUFReader:
# Get Tensor Name # Get Tensor Name
name_len, name_data = self._get_str(offs) name_len, name_data = self._get_str(offs)
offs += int(name_len.nbytes + name_data.nbytes) offs = self.data.tell()
# Get Tensor Dimensions Count # Get Tensor Dimensions Count
n_dims = self._get(offs, np.uint32) n_dims = self._get(offs, np.uint32)
offs += int(n_dims.nbytes) offs = self.data.tell()
# Get Tensor Dimension Array # Get Tensor Dimension Array
dims = self._get(offs, np.uint64, n_dims[0]) dims = self._get(offs, np.uint64, n_dims[0])
offs += int(dims.nbytes) offs = self.data.tell()
# Get Tensor Encoding Scheme Type # Get Tensor Encoding Scheme Type
raw_dtype = self._get(offs, np.uint32) raw_dtype = self._get(offs, np.uint32)
offs += int(raw_dtype.nbytes) offs = self.data.tell()
# Get Tensor Offset # Get Tensor Offset
offset_tensor = self._get(offs, np.uint64) offset_tensor = self._get(offs, np.uint64)
offs += int(offset_tensor.nbytes) offs = self.data.tell()
return ReaderField( return ReaderField(
orig_offs, orig_offs,
@ -237,9 +262,9 @@ class GGUFReader:
for _ in range(count): for _ in range(count):
orig_offs = offs orig_offs = offs
kv_klen, kv_kdata = self._get_str(offs) kv_klen, kv_kdata = self._get_str(offs)
offs += int(kv_klen.nbytes + kv_kdata.nbytes) offs = self.data.tell()
raw_kv_type = self._get(offs, np.uint32) raw_kv_type = self._get(offs, np.uint32)
offs += int(raw_kv_type.nbytes) offs = self.data.tell()
parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
idxs_offs = len(parts) idxs_offs = len(parts)
field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
@ -258,7 +283,7 @@ class GGUFReader:
tensor_fields = [] tensor_fields = []
for _ in range(count): for _ in range(count):
field = self._get_tensor_info_field(offs) field = self._get_tensor_info_field(offs)
offs += sum(int(part.nbytes) for part in field.parts) offs = self.data.tell()
tensor_fields.append(field) tensor_fields.append(field)
return offs, tensor_fields return offs, tensor_fields
@ -311,7 +336,7 @@ class GGUFReader:
n_elements = n_elems, n_elements = n_elems,
n_bytes = n_bytes, n_bytes = n_bytes,
data_offset = data_offs, data_offset = data_offs,
data = self._get(data_offs, item_type, item_count).reshape(np_dims), data = self._get(data_offs, item_type, item_count, use_mmap=True).reshape(np_dims),
field = field, field = field,
)) ))
self.tensors = tensors self.tensors = tensors