Merge 94d814c559 into 9fdb124304
This commit is contained in: ddb5a0254f
1 changed file with 50 additions and 25 deletions
@@ -6,6 +6,7 @@ from __future__ import annotations
 
 import logging
 import os
+import struct
 from collections import OrderedDict
 from typing import Any, Literal, NamedTuple, TypeVar, Union
 
@@ -87,11 +88,15 @@ class GGUFReader:
     }
 
     def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
-        self.data = np.memmap(path, mode = mode)
+        file_mode = "rb+" if mode == 'r+' else 'rb'
+        self.mode = mode
+        self.data = open(path, mode=file_mode)
+        self.mmap = np.memmap(self.data, mode = mode)
         offs = 0
 
         # Check for GGUF magic
-        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
+        self.data.seek(offs)
+        if struct.unpack("<I", self.data.read(4))[0] != GGUF_MAGIC:
             raise ValueError('GGUF magic invalid')
         offs += 4
 
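Note on the new magic check: instead of routing through `_get`, the constructor now reads the first four bytes from the buffered file handle and compares them with `struct`. A minimal standalone sketch of the same check (not part of the patch), assuming `GGUF_MAGIC` is 0x46554747, i.e. the ASCII bytes b"GGUF" read as a little-endian uint32:

    import struct

    GGUF_MAGIC = 0x46554747  # assumed value: ASCII "GGUF" as a little-endian uint32

    def has_gguf_magic(path: str) -> bool:
        # Hypothetical helper: read only the 4-byte header and compare it.
        with open(path, "rb") as f:
            head = f.read(4)
        return len(head) == 4 and struct.unpack("<I", head)[0] == GGUF_MAGIC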
@@ -129,6 +134,9 @@ class GGUFReader:
         self.data_offset = offs
         self._build_tensors(offs, tensors_fields)
 
+    def __del__(self) -> None:
+        self.data.close()
+
     _DT = TypeVar('_DT', bound = npt.DTypeLike)
 
     # Fetch a key/value metadata field by key.
@@ -140,16 +148,24 @@ class GGUFReader:
         return self.tensors[idx]
 
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, use_mmap: bool = False
     ) -> npt.NDArray[Any]:
         count = int(count)
-        itemsize = int(np.empty([], dtype = dtype).itemsize)
+        dtype = np.dtype(dtype)
+        itemsize = dtype.itemsize
         end_offs = offset + itemsize * count
-        return (
-            self.data[offset:end_offs]
-            .view(dtype = dtype)[:count]
-            .newbyteorder(override_order or self.byte_order)
-        )
+        if self.mode != "r" or use_mmap:
+            data = (
+                self.mmap[offset:end_offs]
+                .view(dtype = dtype)[:count]
+                .newbyteorder(override_order or self.byte_order)
+            )
+            self.data.seek(end_offs)
+        else:
+            self.data.seek(offset)
+            dtype = dtype.newbyteorder(override_order or self.byte_order)
+            data = np.frombuffer(self.data.read(itemsize * count), dtype = dtype)
+        return data
 
     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:
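Note on `_get`: writable or explicitly mmap-backed reads still slice `self.mmap` and reinterpret the bytes with `.view()`, while plain read-only access now seeks the file handle and decodes with `np.frombuffer` using a byte-order-qualified dtype. A rough standalone illustration of the two decoding styles (illustrative names, little-endian host assumed; not part of the patch):

    import io
    import numpy as np

    buf = np.arange(4, dtype="<u4").tobytes()      # 16 bytes of sample data
    fh = io.BytesIO(buf)                           # stands in for the opened file

    # mmap-style path: slice raw bytes, then reinterpret them as the target dtype
    raw = np.frombuffer(buf, dtype=np.uint8)
    via_view = raw[0:16].view(np.uint32)[:4]

    # file-read path: seek, read, and decode with an explicit byte order
    fh.seek(0)
    dt = np.dtype(np.uint32).newbyteorder("<")
    via_read = np.frombuffer(fh.read(4 * dt.itemsize), dtype=dt)

    assert (via_view == via_read).all()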
@@ -162,9 +178,18 @@ class GGUFReader:
         self.fields[field.name] = field
         return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
 
-    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
-        slen = self._get(offset, np.uint64)
-        return slen, self._get(offset + 8, np.uint8, slen[0])
+    def _get_str(self, offset: int) -> list[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
+        if self.mode != "r":
+            slen = self._get(offset, np.uint64)
+            sdata = self._get(offset + 8, np.uint8, slen.item())
+        else:
+            # This is faster to return a read-only str structure with less seek calling.
+            self.data.seek(offset)
+            u64 = np.dtype(np.uint64).newbyteorder(self.byte_order)
+            u8 = np.dtype(np.uint8).newbyteorder(self.byte_order)
+            slen = np.frombuffer(self.data.read(8), dtype=u64)
+            sdata = np.frombuffer(self.data.read(slen.item()), dtype=u8)
+        return [slen, sdata]
 
     def _get_field_parts(
         self, orig_offs: int, raw_type: int,
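Note on `_get_str`: a GGUF string is a uint64 length prefix followed by that many UTF-8 bytes, which is why the string branch in `_get_field_parts` below can compute the size as `8 + sparts[0].item()`. A minimal sketch of that layout (hypothetical helper and sample data; not part of the patch):

    import struct

    def read_gguf_string(buf: bytes, offset: int = 0) -> tuple[str, int]:
        # Hypothetical helper: decode one length-prefixed string and return the
        # decoded text plus the offset just past it.
        (slen,) = struct.unpack_from("<Q", buf, offset)
        start = offset + 8
        return buf[start:start + slen].decode("utf-8"), start + slen

    sample = struct.pack("<Q", 7) + b"general"
    text, next_off = read_gguf_string(sample)
    assert text == "general" and next_off == 15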
@@ -175,8 +200,8 @@ class GGUFReader:
         types.append(gtype)
         # Handle strings.
         if gtype == GGUFValueType.STRING:
-            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
-            size = sum(int(part.nbytes) for part in sparts)
+            sparts: list[npt.NDArray[Any]] = self._get_str(offs)
+            size = 8 + sparts[0].item()
             return size, sparts, [1], types
         # Check if it's a simple scalar type.
         nptype = self.gguf_scalar_to_np.get(gtype)
@@ -186,9 +211,9 @@ class GGUFReader:
         # Handle arrays.
         if gtype == GGUFValueType.ARRAY:
             raw_itype = self._get(offs, np.uint32)
-            offs += int(raw_itype.nbytes)
+            offs = self.data.tell()
             alen = self._get(offs, np.uint64)
-            offs += int(alen.nbytes)
+            offs = self.data.tell()
             aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
             data_idxs: list[int] = []
             for idx in range(alen[0]):
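Note on the array branch: `offs` is now refreshed from `self.data.tell()` after each read instead of being advanced by the size of the previous part; the on-disk layout is unchanged, namely a uint32 element-type id followed by a uint64 element count and then the packed elements. A small sketch of that header (hypothetical helper, sample values only; not part of the patch):

    import struct

    def read_array_header(buf: bytes, offset: int) -> tuple[int, int, int]:
        # Hypothetical helper: unpack the element-type id and element count,
        # and return the offset where the element payload starts.
        raw_itype, alen = struct.unpack_from("<IQ", buf, offset)
        return raw_itype, alen, offset + struct.calcsize("<IQ")

    sample = struct.pack("<IQ", 4, 3)          # e.g. type id 4, three elements
    itype, count, payload_off = read_array_header(sample, 0)
    assert (itype, count, payload_off) == (4, 3, 12)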
@@ -208,23 +233,23 @@ class GGUFReader:
 
         # Get Tensor Name
         name_len, name_data = self._get_str(offs)
-        offs += int(name_len.nbytes + name_data.nbytes)
+        offs = self.data.tell()
 
         # Get Tensor Dimensions Count
         n_dims = self._get(offs, np.uint32)
-        offs += int(n_dims.nbytes)
+        offs = self.data.tell()
 
         # Get Tensor Dimension Array
         dims = self._get(offs, np.uint64, n_dims[0])
-        offs += int(dims.nbytes)
+        offs = self.data.tell()
 
         # Get Tensor Encoding Scheme Type
         raw_dtype = self._get(offs, np.uint32)
-        offs += int(raw_dtype.nbytes)
+        offs = self.data.tell()
 
         # Get Tensor Offset
         offset_tensor = self._get(offs, np.uint64)
-        offs += int(offset_tensor.nbytes)
+        offs = self.data.tell()
 
         return ReaderField(
             orig_offs,
@@ -237,9 +262,9 @@ class GGUFReader:
         for _ in range(count):
             orig_offs = offs
             kv_klen, kv_kdata = self._get_str(offs)
-            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
+            offs = self.data.tell()
             raw_kv_type = self._get(offs, np.uint32)
-            offs += int(raw_kv_type.nbytes)
+            offs = self.data.tell()
             parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
             idxs_offs = len(parts)
             field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
@@ -258,7 +283,7 @@ class GGUFReader:
         tensor_fields = []
         for _ in range(count):
             field = self._get_tensor_info_field(offs)
-            offs += sum(int(part.nbytes) for part in field.parts)
+            offs = self.data.tell()
             tensor_fields.append(field)
         return offs, tensor_fields
 
@@ -311,7 +336,7 @@ class GGUFReader:
                 n_elements = n_elems,
                 n_bytes = n_bytes,
                 data_offset = data_offs,
-                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
+                data = self._get(data_offs, item_type, item_count, use_mmap=True).reshape(np_dims),
                 field = field,
             ))
         self.tensors = tensors
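Usage note: the public surface is unchanged by this patch; metadata is now parsed via buffered reads in read-only mode, while tensor data is still served from the memmap (`use_mmap=True` above). A minimal sketch of reading a file with the patched class, using attribute names from the upstream gguf-py reader (`fields`, `tensors`, `ReaderField.types`, `ReaderTensor.name`, `.data`); treat these as assumptions if your version differs, and "model.gguf" is a placeholder path:

    from gguf import GGUFReader

    reader = GGUFReader("model.gguf", mode="r")   # read-only: metadata via buffered reads

    for name, field in reader.fields.items():     # key/value metadata fields
        print(name, field.types)

    for tensor in reader.tensors:                 # tensor payloads come from the memmap
        print(tensor.name, tensor.data.dtype, tensor.data.shape)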